
Mojo function

mla_fused_rope_rmsnorm

mla_fused_rope_rmsnorm[dtype: DType, freq_dtype: DType, gamma_dtype: DType, collection_t: KVCollectionT, //](q_rope_output: TileTensor[dtype, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], q_rope: TileTensor[dtype, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], input_row_offsets: TileTensor[DType.uint32, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], freqs_cis: TileTensor[freq_dtype, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], gamma: TileTensor[gamma_dtype, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], kv_collection: collection_t, layer_idx: UInt32, epsilon: Float32, ctx: DeviceContext)

Launches the fused RoPE and RMSNorm kernel for Multi-head Latent Attention (MLA).

This function fuses two operations into a single kernel launch:

  1. Rotary position embedding (RoPE), applied to the rope portions of the query and of the key cache.
  2. RMSNorm, applied to the non-rope portion of the key cache.
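The per-row math of the two fused operations can be sketched in plain Python as below. This is a hedged reference sketch, not the kernel's implementation. It assumes that each `freqs_cis` row stores interleaved (cos, sin) pairs, that RoPE rotates adjacent element pairs, that RMSNorm scales by `gamma` directly (no `1 + gamma` offset), and that a key-cache row stores the normalized part first and the rope part after it. The helper names and the `kv_norm_dim` split are hypothetical.

```python
import math


def rope_interleaved(x, freqs_row):
    """Rotate adjacent pairs (x[2i], x[2i+1]) by the angle encoded in freqs_row.

    Assumption: freqs_row holds interleaved (cos, sin) pairs, so
    len(freqs_row) == len(x) == rope_dim.
    """
    assert len(x) == len(freqs_row) and len(x) % 2 == 0
    out = []
    for i in range(0, len(x), 2):
        re, im = x[i], x[i + 1]
        cos, sin = freqs_row[i], freqs_row[i + 1]
        # Complex multiply (re + i*im) * (cos + i*sin).
        out.append(re * cos - im * sin)
        out.append(re * sin + im * cos)
    return out


def rms_norm(x, gamma, epsilon):
    """y = x / sqrt(mean(x^2) + epsilon) * gamma (plain gamma, no offset)."""
    mean_sq = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(mean_sq + epsilon)
    return [v * scale * g for v, g in zip(x, gamma)]


def rope_query_token(q_heads, freqs_row):
    """Apply RoPE to every head of one token's q_rope slice (num_heads x rope_dim)."""
    return [rope_interleaved(head, freqs_row) for head in q_heads]


def mla_key_row(k_row, freqs_row, gamma, epsilon, kv_norm_dim):
    """Hypothetical key-cache row layout: [non-rope (kv_norm_dim) | rope (rope_dim)]."""
    nope, rope = k_row[:kv_norm_dim], k_row[kv_norm_dim:]
    return rms_norm(nope, gamma, epsilon) + rope_interleaved(rope, freqs_row)
```

For example, with `freqs_row = [0.0, 1.0]` (a 90 degree rotation), the pair `[1.0, 0.0]` becomes `[0.0, 1.0]`, and `rms_norm([3.0, 4.0], [1.0, 1.0], 0.0)` yields a vector whose mean square is 1. Fusing both steps lets the kernel read each key-cache row once and write both results back in the same pass.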

Parameters:

  • dtype (DType): Data type of query tensors.
  • freq_dtype (DType): Data type of the precomputed RoPE frequency (cosine/sine) values.
  • gamma_dtype (DType): Data type of RMSNorm gamma weights.
  • collection_t (KVCollectionT): Type of the KV cache collection.

Args:

  • q_rope_output (TileTensor): Output tensor for RoPE-applied query projections. Shape: [tot_seq_len, num_heads, rope_dim].
  • q_rope (TileTensor): Input query rope projections. Shape: [tot_seq_len, num_heads, rope_dim].
  • input_row_offsets (TileTensor): Row offsets marking where each request starts in the ragged token dimension; the final entry equals tot_seq_len. Shape: [num_batches + 1].
  • freqs_cis (TileTensor): Precomputed RoPE frequency values. Shape: [max_seq_len, rope_dim].
  • gamma (TileTensor): RMSNorm gamma weights. Shape: [kv_norm_dim].
  • kv_collection (collection_t): Paged KV cache collection.
  • layer_idx (UInt32): Index of the current transformer layer.
  • epsilon (Float32): Small constant for numerical stability in RMSNorm.
  • ctx (DeviceContext): Device context for kernel execution.
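Because the query tensors are ragged (all requests packed along `tot_seq_len`), each token's request and position must be recovered from `input_row_offsets` before the matching `freqs_cis` row can be read. The sketch below shows one way to do that in Python. It is illustrative only: the `cache_lengths` argument is a hypothetical stand-in for tokens already stored in the KV cache, and whether the kernel offsets positions this way is an assumption, not something this page states.

```python
from bisect import bisect_right


def token_positions(input_row_offsets, cache_lengths=None):
    """Map each packed token to (request index, absolute position).

    Request b owns tokens [input_row_offsets[b], input_row_offsets[b + 1]).
    cache_lengths is hypothetical: the number of tokens request b already has
    in the KV cache (defaults to 0 for every request).
    """
    num_batches = len(input_row_offsets) - 1
    if cache_lengths is None:
        cache_lengths = [0] * num_batches
    result = []
    for tok in range(input_row_offsets[-1]):
        # Last request whose start offset is <= tok; this skips empty requests.
        b = bisect_right(input_row_offsets, tok) - 1
        pos = cache_lengths[b] + (tok - input_row_offsets[b])
        result.append((b, pos))
    return result
```

With `input_row_offsets = [0, 3, 5]`, tokens 0 to 2 map to request 0 at positions 0, 1, 2 and tokens 3 and 4 map to request 1 at positions 0 and 1; each position would then select the row `freqs_cis[pos]`, which must be smaller than `max_seq_len`.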
