Mojo struct
MatmulFusedPartialRMSNorm
struct MatmulFusedPartialRMSNorm
Fuses GEMV (M=1 matmul) with partial RMS normalization.
Computes y = x @ W.T, then applies RMS normalization to the first N_normed columns while passing the remaining columns through unchanged.
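The computation described above can be sketched as a plain-Python reference. This is not the kernel's implementation, only an illustration of the arithmetic. It assumes the common RMSNorm form, in which epsilon is added to the mean square under the square root and weight_offset is added to gamma before scaling; this page does not state either detail. The function name matmul_partial_rms_norm and its list-based arguments are hypothetical.

```python
import math


def matmul_partial_rms_norm(x, w, gamma, n_normed, epsilon=1e-6, weight_offset=0.0):
    """Reference for y = x @ W.T with RMS norm on the first n_normed columns.

    x:      list of K floats (the single input row, M = 1)
    w:      list of N rows, each holding K floats (so W.T is K x N)
    gamma:  list of n_normed scale factors
    """
    # GEMV: y[j] = sum over k of x[k] * w[j][k]
    y = [sum(xk * wjk for xk, wjk in zip(x, row)) for row in w]
    if n_normed == 0:
        return y  # nothing to normalize

    head = y[:n_normed]
    # Assumption: standard RMSNorm, i.e. divide by sqrt(mean(head^2) + epsilon).
    mean_sq = sum(v * v for v in head) / len(head)
    inv_rms = 1.0 / math.sqrt(mean_sq + epsilon)
    # Assumption: weight_offset is added to gamma before scaling.
    normed = [v * inv_rms * (g + weight_offset) for v, g in zip(head, gamma)]

    # The remaining columns pass through unchanged.
    return normed + y[n_normed:]


out = matmul_partial_rms_norm(
    x=[1.0, 2.0],
    w=[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
    gamma=[1.0, 1.0],
    n_normed=2,
    epsilon=0.0,
)
```

With gamma set to ones and epsilon set to zero, the first two entries of out have a root mean square of 1, and the third entry is the raw dot product.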
Implemented traits
AnyType,
ImplicitlyDestructible
Methods
execute
static def execute[dtype: DType, rank: Int, target: StringSlice[StaticConstantOrigin], transpose_b: Bool = True](normed_output: ManagedTensorSlice[Output, static_spec=normed_output.static_spec], unnormed_output: ManagedTensorSlice[Output, static_spec=unnormed_output.static_spec], input: ManagedTensorSlice[Input, static_spec=input.static_spec], weight: ManagedTensorSlice[Input, static_spec=weight.static_spec], gamma: ManagedTensorSlice[Input, static_spec=gamma.static_spec], epsilon: Scalar[dtype], weight_offset: Scalar[dtype], ctx: DeviceContext)
Execute fused GEMV + partial RMS norm.
Calls gemv_and_partial_norm from nn.gemv_partial_norm, which
computes y = x @ W.T and then partitions y into the normed and unnormed
outputs.
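The partition step can be pictured as a column split of the single result row. The sketch below is a hedged illustration only: split_columns is a hypothetical helper, and the assumption that normed_output receives the first n_normed columns (after normalization) while unnormed_output receives the rest is inferred from the description above, not from the kernel source.

```python
def split_columns(y, n_normed):
    """Split one result row into (first n_normed columns, remaining columns)."""
    if not 0 <= n_normed <= len(y):
        raise ValueError("n_normed must lie in [0, len(y)]")
    # Assumption: the head is what gets RMS-normalized and written to
    # normed_output; the tail is written to unnormed_output as-is.
    return y[:n_normed], y[n_normed:]


head, tail = split_columns([1.0, 2.0, 3.0, 4.0, 5.0], 2)
```

Under this assumption, a row of N columns yields a normed part of width n_normed and an unnormed part of width N - n_normed.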
shape
static def shape[dtype: DType, rank: Int](input: ManagedTensorSlice[Input, static_spec=input.static_spec], weight: ManagedTensorSlice[Input, static_spec=weight.static_spec], gamma: ManagedTensorSlice[Input, static_spec=gamma.static_spec], epsilon: Scalar[dtype], weight_offset: Scalar[dtype]) -> IndexList[rank]
Returns: IndexList[rank]