Mojo struct

RMSNormFusedResidual

struct RMSNormFusedResidual

RMS normalization with fused residual connection for Mamba blocks.

Computes the sum (input + residual_input), applies RMS normalization to that sum, and writes two results: the normalized output and the un-normalized sum itself (residual_output). This matches the fused residual + norm pattern used in Mamba models.
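As a rough illustration of the computation described above, here is a minimal pure-Python sketch for a single row of length hidden_size. It is not the Mojo kernel. The function name rms_norm_fused_residual is hypothetical, dropout is left out (equivalent to dropout_p = 0), eps is placed inside the square root following the usual RMSNorm convention, and applying the offset as (weight + weight_offset) is an assumption about how weight_offset is used.

```python
import math


def rms_norm_fused_residual(x, residual, weight, eps=1e-6, weight_offset=0.0):
    """Hypothetical reference for one row; returns (output, residual_output)."""
    # Step 1: fused residual add. This sum is also the second result.
    residual_output = [a + b for a, b in zip(x, residual)]

    # Step 2: root-mean-square over the hidden dimension (eps assumed inside the sqrt).
    mean_sq = sum(v * v for v in residual_output) / len(residual_output)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)

    # Step 3: scale by the (offset) weight. Assumes the offset is added to the weight.
    output = [v * inv_rms * (w + weight_offset)
              for v, w in zip(residual_output, weight)]
    return output, residual_output


out, res_out = rms_norm_fused_residual(
    [1.0, 2.0, 3.0, 4.0],   # input
    [1.0, 0.0, -1.0, 0.0],  # residual_input
    [1.0, 1.0, 1.0, 1.0],   # weight (gamma)
)
# res_out is [2.0, 2.0, 2.0, 4.0]; out has a mean square very close to 1.
```

With unit weights and no offset, the normalized row has a mean square of approximately 1, which is a quick sanity check for any implementation of this pattern.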

Reference: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/layer_norm.py

Tensor Shapes:

- output: (..., hidden_size) - Normalized output tensor (same shape as input).
- residual_output: (..., hidden_size) - Pre-normalization sum (input + residual_input).
- input: (..., hidden_size) - Input tensor to normalize.
- residual_input: (..., hidden_size) - Residual tensor to add before normalization.
- weight: (hidden_size,) - Weight tensor (gamma) for normalization. Named gamma in the method signatures.
- eps: Scalar - Epsilon value for numerical stability (default: 1e-6). Named epsilon in the method signatures.
- weight_offset: Scalar - Offset added to weight before normalization (default: 0.0).
- dropout_p: Scalar - Dropout probability (default: 0.0).
- seed: Scalar[uint64] - Random seed for dropout (default: 0).

Compile-time Options:

- multiply_before_cast: If True (the default), multiplies by weight before casting to the output dtype. If False, casts to the output dtype first and then multiplies by weight.
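The two orderings can differ in the last bits of a low-precision result, because casting first adds a rounding step before the multiply. The sketch below simulates this in pure Python, assuming a float16 output dtype emulated with the struct module's half-precision format. The helper names to_f16 and scale_row are hypothetical, and the weights are assumed to be exactly representable in float16.

```python
import struct


def to_f16(v):
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("<e", struct.pack("<e", v))[0]


def scale_row(normalized, weight, multiply_before_cast=True):
    """Apply the weight to already-normalized values in one of two orders."""
    if multiply_before_cast:
        # Multiply in high precision, then round once to the output dtype.
        return [to_f16(n * w) for n, w in zip(normalized, weight)]
    # Round to the output dtype first, then multiply in that dtype
    # (the product is rounded again to stay in float16).
    return [to_f16(to_f16(n) * w) for n, w in zip(normalized, weight)]


normalized = [0.333333, 1.23456, -0.98765]
weight = [1.5, 0.75, 2.0]
before = scale_row(normalized, weight, multiply_before_cast=True)
after = scale_row(normalized, weight, multiply_before_cast=False)
```

Both variants return values representable in float16 and agree to within a few units in the last place. Which ordering to pick usually depends on which reference model's numerics need to be matched.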

Implemented traits

AnyType, ImplicitlyDestructible

Methods

execute

static execute[dtype: DType, rank: Int, target: StringSlice[StaticConstantOrigin], multiply_before_cast: Bool = True](output: ManagedTensorSlice[Output, static_spec=output.static_spec], residual_output: ManagedTensorSlice[Output, static_spec=residual_output.static_spec], input: ManagedTensorSlice[Input, static_spec=input.static_spec], residual_input: ManagedTensorSlice[Input, static_spec=residual_input.static_spec], gamma: ManagedTensorSlice[Input, static_spec=gamma.static_spec], epsilon: Scalar[dtype], weight_offset: Scalar[dtype], dropout_p: Scalar[dtype], seed: UInt64, ctx: DeviceContext)

shape

static shape[dtype: DType, rank: Int](input: ManagedTensorSlice[Input, static_spec=input.static_spec], residual_input: ManagedTensorSlice[Input, static_spec=residual_input.static_spec], gamma: ManagedTensorSlice[Input, static_spec=gamma.static_spec], epsilon: Scalar[dtype], weight_offset: Scalar[dtype], dropout_p: Scalar[dtype], seed: UInt64) -> IndexList[rank]

Returns:

IndexList[rank]
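The page does not state which shape is returned. Since output and residual_output both have the same shape as input, a plausible reading is that shape returns the input's shape. The Python sketch below illustrates that reading. The function name fused_residual_shape and the validation checks are assumptions for illustration, not documented behavior.

```python
def fused_residual_shape(input_shape, residual_shape, weight_shape):
    """Hypothetical shape inference: the outputs share the input's shape."""
    input_shape = tuple(input_shape)
    # Assumed check: the residual must match the input elementwise.
    if tuple(residual_shape) != input_shape:
        raise ValueError("residual_input must have the same shape as input")
    # Assumed check: the weight spans only the last (hidden) dimension.
    if tuple(weight_shape) != (input_shape[-1],):
        raise ValueError("weight must have shape (hidden_size,)")
    return input_shape


# A (batch, seq_len, hidden_size) input keeps its shape.
result = fused_residual_shape((2, 3, 8), (2, 3, 8), (8,))
```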