Mojo struct

MatmulFusedPartialRMSNorm

struct MatmulFusedPartialRMSNorm

Fuses GEMV (M=1 matmul) with partial RMS normalization.

Computes y = x @ W.T, then applies RMS normalization to the first N_normed columns while passing the remaining columns through unchanged.
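Written out, the computation described above might look as follows. This is a sketch under stated assumptions, not a statement of the kernel's exact arithmetic. It assumes that the mean square is taken over the normed columns only, that `epsilon` is added inside the square root, and that `weight_offset` is added to `gamma` before scaling. Here $N$ is the total number of output columns.

```latex
\begin{aligned}
y &= x W^{\top} \in \mathbb{R}^{1 \times N}, \\
r &= \sqrt{\frac{1}{N_{\text{normed}}} \sum_{j=1}^{N_{\text{normed}}} y_j^{2} + \epsilon}, \\
\text{normed\_output}_j &= \frac{y_j}{r}\,(\gamma_j + \text{weight\_offset}), \quad 1 \le j \le N_{\text{normed}}, \\
\text{unnormed\_output}_j &= y_{N_{\text{normed}} + j}, \quad 1 \le j \le N - N_{\text{normed}}.
\end{aligned}
```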

Implemented traits

AnyType, ImplicitlyDestructible

Methods

execute

static def execute[dtype: DType, rank: Int, target: StringSlice[StaticConstantOrigin], transpose_b: Bool = True](normed_output: ManagedTensorSlice[Output, static_spec=normed_output.static_spec], unnormed_output: ManagedTensorSlice[Output, static_spec=unnormed_output.static_spec], input: ManagedTensorSlice[Input, static_spec=input.static_spec], weight: ManagedTensorSlice[Input, static_spec=weight.static_spec], gamma: ManagedTensorSlice[Input, static_spec=gamma.static_spec], epsilon: Scalar[dtype], weight_offset: Scalar[dtype], ctx: DeviceContext)

Execute fused GEMV + partial RMS norm.

Calls gemv_and_partial_norm from nn.gemv_partial_norm, which computes y = x @ W.T and then partitions y into the normed and unnormed outputs.
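As a way to check expected values, the behaviour described here can be mimicked in plain Python. This is a minimal reference sketch, not the Mojo kernel and not the actual `gemv_and_partial_norm` implementation. The function name `gemv_partial_rms_norm` is hypothetical. The sketch assumes that N_normed equals the length of `gamma`, that `epsilon` is added to the mean square before the square root, and that `weight_offset` is added to `gamma`. It covers only the default `transpose_b = True` layout, where each row of `W` produces one output column.

```python
import math


def gemv_partial_rms_norm(x, W, gamma, epsilon, weight_offset=0.0):
    """Reference sketch (hypothetical helper, not the MAX API).

    x:     list of K floats (the single input row, M = 1)
    W:     list of N rows, each a list of K floats
    gamma: list of N_normed floats (assumed to define N_normed)
    """
    n_normed = len(gamma)  # assumption: N_normed is given by gamma's length

    # GEMV: y[j] = dot(x, W[j]), i.e. y = x @ W.T for a single row x.
    y = [sum(xi * wi for xi, wi in zip(x, row)) for row in W]

    # Partition y into the columns to normalize and the pass-through columns.
    head, tail = y[:n_normed], y[n_normed:]

    # RMS over the normed columns only (assumption), epsilon inside the sqrt.
    mean_sq = sum(v * v for v in head) / n_normed
    inv_rms = 1.0 / math.sqrt(mean_sq + epsilon)

    # Assumption: weight_offset is added to gamma before scaling.
    normed = [v * inv_rms * (g + weight_offset) for v, g in zip(head, gamma)]
    return normed, tail


if __name__ == "__main__":
    x = [3.0, 4.0]
    W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # N = 3 output columns
    normed, unnormed = gemv_partial_rms_norm(x, W, gamma=[1.0, 1.0], epsilon=0.0)
    print(normed, unnormed)
```

In this example the first two columns of y are normalized and the third is returned unchanged, which mirrors the separate `normed_output` and `unnormed_output` arguments of `execute`.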

shape

static def shape[dtype: DType, rank: Int](input: ManagedTensorSlice[Input, static_spec=input.static_spec], weight: ManagedTensorSlice[Input, static_spec=weight.static_spec], gamma: ManagedTensorSlice[Input, static_spec=gamma.static_spec], epsilon: Scalar[dtype], weight_offset: Scalar[dtype]) -> IndexList[rank]

Returns:

IndexList[rank]