
Mojo function

mla_fused_rope_rmsnorm

mla_fused_rope_rmsnorm[dtype: DType, freq_dtype: DType, gamma_dtype: DType, collection_t: KVCollectionT, //](q_rope_output: TileTensor[dtype, q_rope_output.LayoutType, q_rope_output.origin, address_space=q_rope_output.address_space, linear_idx_type=q_rope_output.linear_idx_type, element_shape_types=q_rope_output.element_shape_types], q_rope: TileTensor[dtype, q_rope.LayoutType, q_rope.origin, address_space=q_rope.address_space, linear_idx_type=q_rope.linear_idx_type, element_shape_types=q_rope.element_shape_types], input_row_offsets: TileTensor[DType.uint32, input_row_offsets.LayoutType, input_row_offsets.origin, address_space=input_row_offsets.address_space, linear_idx_type=input_row_offsets.linear_idx_type, element_shape_types=input_row_offsets.element_shape_types], freqs_cis: TileTensor[freq_dtype, freqs_cis.LayoutType, freqs_cis.origin, address_space=freqs_cis.address_space, linear_idx_type=freqs_cis.linear_idx_type, element_shape_types=freqs_cis.element_shape_types], gamma: TileTensor[gamma_dtype, gamma.LayoutType, gamma.origin, address_space=gamma.address_space, linear_idx_type=gamma.linear_idx_type, element_shape_types=gamma.element_shape_types], kv_collection: collection_t, layer_idx: UInt32, epsilon: Float32, ctx: DeviceContext)

Launches the fused RoPE and RMSNorm kernel for Multi-head Latent Attention (MLA).

This function fuses two operations:

  1. RoPE applied to the rope portions of the query and of the key cache.
  2. RMSNorm applied to the non-rope portion of the key cache.
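The two fused operations can be written out for a single row as a plain-Python reference sketch. This is not the Mojo kernel or its API. Several details are assumptions made for illustration: the interleaved cos/sin layout of `freqs_cis` rows, the `theta=10000.0` base used by the hypothetical `make_freqs_cis` helper, and the `[normalized part | rope part]` key-row split assumed by the hypothetical `process_k_row`.

```python
import math

# Plain-Python reference sketch, not the Mojo kernel.
# Assumption: each freqs_cis row stores cos/sin interleaved:
#   [cos_0, sin_0, cos_1, sin_1, ...], one (cos, sin) pair per rotated pair.

def make_freqs_cis(max_seq_len, rope_dim, theta=10000.0):
    """Hypothetical helper: build a max_seq_len x rope_dim cos/sin table."""
    table = []
    for pos in range(max_seq_len):
        row = []
        for i in range(rope_dim // 2):
            angle = pos * theta ** (-2.0 * i / rope_dim)
            row += [math.cos(angle), math.sin(angle)]
        table.append(row)
    return table

def apply_rope(x, cos_sin):
    """Rotate each adjacent pair (x[2i], x[2i+1]) by the angle for pair i."""
    assert len(x) == len(cos_sin) and len(x) % 2 == 0
    out = [0.0] * len(x)
    for i in range(0, len(x), 2):
        c, s = cos_sin[i], cos_sin[i + 1]
        out[i] = x[i] * c - x[i + 1] * s
        out[i + 1] = x[i] * s + x[i + 1] * c
    return out

def rms_norm(x, gamma, epsilon):
    """y_i = gamma_i * x_i / sqrt(mean(x^2) + epsilon)."""
    assert len(x) == len(gamma)
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + epsilon)
    return [g * v * inv_rms for v, g in zip(x, gamma)]

def process_k_row(k_row, kv_norm_dim, cos_sin, gamma, epsilon):
    """Hypothetical key-cache row layout: [kv_norm_dim values | rope_dim values].
    RMSNorm the first part, apply RoPE to the second, as one fused pass."""
    normed = rms_norm(k_row[:kv_norm_dim], gamma, epsilon)
    roped = apply_rope(k_row[kv_norm_dim:], cos_sin)
    return normed + roped

if __name__ == "__main__":
    freqs = make_freqs_cis(max_seq_len=8, rope_dim=4)
    print(process_k_row([3.0, 4.0, 1.0, 2.0, 3.0, 4.0], kv_norm_dim=2,
                        cos_sin=freqs[3], gamma=[1.0, 1.0], epsilon=1e-6))
```

In standard RoPE the rotation depends only on the token's position. Under that convention, every query head's `rope_dim` slice for a given token would use the same `freqs_cis` row. Whether this kernel pairs elements interleaved, as sketched, or as split halves is not stated on this page.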

Parameters:

  • dtype (DType): Data type of query tensors.
  • freq_dtype (DType): Data type of frequency cosine/sine values.
  • gamma_dtype (DType): Data type of RMSNorm gamma weights.
  • collection_t (KVCollectionT): Type of the KV cache collection.

Args:

  • q_rope_output (TileTensor): Output tensor for RoPE-applied query projections. Shape: [tot_seq_len, num_heads, rope_dim].
  • q_rope (TileTensor): Input query rope projections. Shape: [tot_seq_len, num_heads, rope_dim].
  • input_row_offsets (TileTensor): Row offsets indicating request boundaries. Shape: [num_batches + 1].
  • freqs_cis (TileTensor): Precomputed RoPE frequency (cosine/sine) values. Shape: [max_seq_len, rope_dim].
  • gamma (TileTensor): RMSNorm gamma weights. Shape: [kv_norm_dim].
  • kv_collection (collection_t): Paged KV cache collection.
  • layer_idx (UInt32): Index of the current transformer layer.
  • epsilon (Float32): Small constant for numerical stability in RMSNorm.
  • ctx (DeviceContext): Device context for kernel execution.
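`q_rope` packs all requests along `tot_seq_len`. Before a `freqs_cis` row can be chosen, each token's request and position must therefore be recovered from `input_row_offsets`. The plain-Python sketch below makes two assumptions. First, `input_row_offsets[0] == 0` and `input_row_offsets[-1] == tot_seq_len`. Second, `cache_lengths` (the number of tokens already cached per request) is a hypothetical input. This page does not document how the kernel actually derives a token's position.

```python
from bisect import bisect_right

def token_coordinates(input_row_offsets, cache_lengths=None):
    """Map each token of a ragged [tot_seq_len, ...] tensor to (batch_idx, position).

    Assumes input_row_offsets is non-decreasing, starts at 0, and ends at
    tot_seq_len. cache_lengths is a hypothetical per-request count of tokens
    already in the KV cache (defaults to 0 for every request).
    """
    num_batches = len(input_row_offsets) - 1
    tot_seq_len = input_row_offsets[-1]
    if cache_lengths is None:
        cache_lengths = [0] * num_batches
    coords = []
    for tok in range(tot_seq_len):
        # Last request whose start offset is <= tok; skips empty requests.
        b = bisect_right(input_row_offsets, tok) - 1
        local = tok - input_row_offsets[b]
        coords.append((b, cache_lengths[b] + local))
    return coords

if __name__ == "__main__":
    # Two requests: 3 new tokens, then 2 new tokens appended after 10 cached ones.
    print(token_coordinates([0, 3, 5], cache_lengths=[0, 10]))
```

The resulting position would select the `freqs_cis` row used for that token's query and key rope slices. The binary search keeps the per-token lookup at O(log num_batches).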
