Mojo struct
RMSNormFusedResidual
struct RMSNormFusedResidual
RMS normalization with fused residual connection for Mamba blocks.
Computes RMS normalization of the sum input + residual_input and returns both the normalized result (output) and the pre-normalization sum (residual_output), which is typically passed to the next block as its residual. This matches the fused residual-add + norm pattern used in Mamba models.
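Written out for one row of length H = hidden_size, the computation is as follows. This is a sketch that assumes the standard RMSNorm definition (ε inside the square root) and that weight_offset δ is added elementwise to γ (gamma). Dropout is omitted. Here x is input, s is residual_input, r is residual_output, and y is output.

$$
\begin{aligned}
r_i &= x_i + s_i, \qquad i = 1, \dots, H \\
\mathrm{RMS}(r) &= \sqrt{\frac{1}{H} \sum_{j=1}^{H} r_j^{2} + \epsilon} \\
y_i &= \frac{r_i}{\mathrm{RMS}(r)} \, (\gamma_i + \delta)
\end{aligned}
$$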
Reference: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/layer_norm.py
Tensor shapes and scalar parameters:
- output: (..., hidden_size) - Normalized output tensor (same shape as input).
- residual_output: (..., hidden_size) - Pre-normalization sum (input + residual_input).
- input: (..., hidden_size) - Input tensor to normalize.
- residual_input: (..., hidden_size) - Residual tensor added to input before normalization.
- gamma (weight): (hidden_size,) - Weight tensor applied to the normalized values.
- epsilon (eps): Scalar - Epsilon value for numerical stability (default: 1e-6).
- weight_offset: Scalar - Offset added to each element of gamma before it is applied (default: 0.0).
- dropout_p: Scalar - Dropout probability (default: 0.0).
- seed: UInt64 - Random seed for dropout (default: 0).
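The pure-Python reference model below may help when checking kernel outputs on small inputs. It is a sketch, not the MAX kernel: leading dimensions are assumed to be flattened into a list of rows, and the dropout handling (inverted dropout applied to input before the residual add) is an assumption. Python's random module will not reproduce the kernel's random stream for a given seed. The function name rms_norm_fused_residual_ref is hypothetical.

```python
import math
import random


def rms_norm_fused_residual_ref(x, residual, gamma, eps=1e-6, weight_offset=0.0,
                                dropout_p=0.0, seed=0):
    """Hypothetical reference model.

    x and residual are lists of rows, each of length hidden_size.
    Returns (output, residual_output) as lists of rows.
    """
    if not 0.0 <= dropout_p < 1.0:
        raise ValueError("dropout_p must be in [0, 1)")
    rng = random.Random(seed)  # will not match the kernel's RNG stream
    output, residual_output = [], []
    for x_row, r_row in zip(x, residual):
        if dropout_p > 0.0:
            # Assumption: inverted dropout on input, applied before the residual add.
            keep_scale = 1.0 / (1.0 - dropout_p)
            x_row = [0.0 if rng.random() < dropout_p else v * keep_scale for v in x_row]
        # Pre-normalization sum, returned as residual_output.
        summed = [a + b for a, b in zip(x_row, r_row)]
        # Root mean square over the hidden dimension, with eps inside the sqrt.
        mean_sq = sum(v * v for v in summed) / len(summed)
        inv_rms = 1.0 / math.sqrt(mean_sq + eps)
        # Scale each normalized element by (gamma + weight_offset).
        output.append([v * inv_rms * (g + weight_offset) for v, g in zip(summed, gamma)])
        residual_output.append(summed)
    return output, residual_output
```

For example, with input [[1.0, 2.0]], residual_input [[2.0, 2.0]], gamma [1.0, 1.0] and eps=0.0, residual_output is [[3.0, 4.0]] and output is [3.0, 4.0] divided by sqrt(12.5).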
Compile-time Options:
- multiply_before_cast (default: True): If True, multiplies by gamma before casting to the output dtype. If False, casts to the output dtype before multiplying by gamma.
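To show why the order can change results, the sketch below uses float16 as a stand-in output dtype (rounded through Python's struct "e" format) and applies gamma to a value that has already been normalized in higher precision. The helper names to_f16 and apply_gamma are hypothetical, and the kernel's exact rounding path is assumed rather than documented here.

```python
import struct


def to_f16(v: float) -> float:
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("<e", struct.pack("<e", v))[0]


def apply_gamma(normed, gamma, weight_offset=0.0, multiply_before_cast=True):
    """Hypothetical model of the final scaling step for one row."""
    if multiply_before_cast:
        # Multiply in high precision, then round once to the output dtype.
        return [to_f16(n * (g + weight_offset)) for n, g in zip(normed, gamma)]
    # Round the normalized value to the output dtype first, then multiply
    # and round the product again.
    return [to_f16(to_f16(n) * to_f16(g + weight_offset)) for n, g in zip(normed, gamma)]
```

With n = 1 + 2**-12 + 2**-13 and gamma 7.0, multiplying first gives 7.00390625, while casting first rounds n down to 1.0 and gives 7.0.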
Implemented traits
AnyType, ImplicitlyDestructible
Methods
execute
static execute[dtype: DType, rank: Int, target: StringSlice[StaticConstantOrigin], multiply_before_cast: Bool = True](output: ManagedTensorSlice[Output, static_spec=output.static_spec], residual_output: ManagedTensorSlice[Output, static_spec=residual_output.static_spec], input: ManagedTensorSlice[Input, static_spec=input.static_spec], residual_input: ManagedTensorSlice[Input, static_spec=residual_input.static_spec], gamma: ManagedTensorSlice[Input, static_spec=gamma.static_spec], epsilon: Scalar[dtype], weight_offset: Scalar[dtype], dropout_p: Scalar[dtype], seed: UInt64, ctx: DeviceContext)
shape
static shape[dtype: DType, rank: Int](input: ManagedTensorSlice[Input, static_spec=input.static_spec], residual_input: ManagedTensorSlice[Input, static_spec=residual_input.static_spec], gamma: ManagedTensorSlice[Input, static_spec=gamma.static_spec], epsilon: Scalar[dtype], weight_offset: Scalar[dtype], dropout_p: Scalar[dtype], seed: UInt64) -> IndexList[rank]
Returns: IndexList[rank] giving the shape of the output tensor, which is the same as the shape of input.
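Because output has the same shape as input, shape inference for this op reduces to returning the input shape. The sketch below models that in Python; the function name and the consistency checks on residual_input and gamma are hypothetical, based only on the tensor shapes listed above, and may not match what the Mojo shape method actually validates.

```python
def fused_residual_rms_norm_shape(input_shape, residual_shape, gamma_shape):
    """Hypothetical shape rule: output (and residual_output) match input."""
    input_shape = tuple(input_shape)
    # residual_input is added elementwise, so its shape must equal input's.
    if tuple(residual_shape) != input_shape:
        raise ValueError("residual_input must have the same shape as input")
    # gamma is expected to be 1-D with length hidden_size (the last input dim).
    if tuple(gamma_shape) != (input_shape[-1],):
        raise ValueError("gamma must have shape (hidden_size,)")
    return input_shape
```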