Mojo function
mla_fused_rope_rmsnorm
mla_fused_rope_rmsnorm[dtype: DType, freq_dtype: DType, gamma_dtype: DType, collection_t: KVCollectionT, //](q_rope_output: TileTensor[dtype, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], q_rope: TileTensor[dtype, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], input_row_offsets: TileTensor[DType.uint32, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], freqs_cis: TileTensor[freq_dtype, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], gamma: TileTensor[gamma_dtype, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], kv_collection: collection_t, layer_idx: UInt32, epsilon: Float32, ctx: DeviceContext)
Launches the fused RoPE and RMSNorm kernel for MLA attention.
This function fuses two operations:
- RoPE applied to the rope portions of the query and the key cache.
- RMSNorm applied to the non-rope portion of the key cache.
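As a rough illustration of the two fused operations, the following pure-Python sketch applies each one to a single vector. It is a reference for the math only, not the kernel itself. The helper names `rms_norm` and `rope_interleaved` are hypothetical, and the interleaved (cos, sin) layout of the frequency values is an assumption; this page does not state which pairing convention the kernel uses.

```python
import math


def rms_norm(x, gamma, epsilon):
    """Scale x by the reciprocal of its root mean square, then by gamma."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + epsilon)
    return [v * inv_rms * g for v, g in zip(x, gamma)]


def rope_interleaved(x, freqs):
    """Rotate consecutive (even, odd) pairs of x by per-pair angles.

    Assumption: freqs holds interleaved (cos, sin) values for one position,
    so freqs[i] is the cosine and freqs[i + 1] the sine for the pair
    (x[i], x[i + 1]). The real kernel's layout may differ.
    """
    out = [0.0] * len(x)
    for i in range(0, len(x), 2):
        cos, sin = freqs[i], freqs[i + 1]
        out[i] = x[i] * cos - x[i + 1] * sin
        out[i + 1] = x[i] * sin + x[i + 1] * cos
    return out
```

In the fused kernel, the rotation would apply to the rope part of each query and key cache row, and the normalization to the non-rope part of each key cache row, in a single launch.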
Parameters:
- dtype (DType): Data type of query tensors.
- freq_dtype (DType): Data type of frequency cosine/sine values.
- gamma_dtype (DType): Data type of RMSNorm gamma weights.
- collection_t (KVCollectionT): Type of the KV cache collection.
Args:
- q_rope_output (TileTensor): Output tensor for RoPE-applied query projections. Shape: [tot_seq_len, num_heads, rope_dim].
- q_rope (TileTensor): Input query rope projections. Shape: [tot_seq_len, num_heads, rope_dim].
- input_row_offsets (TileTensor): Row offsets indicating request boundaries. Shape: [num_batches + 1].
- freqs_cis (TileTensor): Precomputed RoPE frequency values. Shape: [max_seq_len, rope_dim].
- gamma (TileTensor): RMSNorm gamma weights. Shape: [kv_norm_dim].
- kv_collection (collection_t): Paged KV cache collection.
- layer_idx (UInt32): Index of the current transformer layer.
- epsilon (Float32): Small constant for numerical stability in RMSNorm.
- ctx (DeviceContext): Device context for kernel execution.
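To make the `input_row_offsets` argument more concrete, here is a small Python sketch of how a flat token index along the `tot_seq_len` dimension could be mapped back to a request. It assumes the offsets are cumulative token counts, which is the usual ragged-batch layout and is consistent with the `[num_batches + 1]` shape, but is not confirmed by this page. The helper name `token_to_request` is hypothetical.

```python
def token_to_request(input_row_offsets, token_idx):
    """Return (request index, offset of the token within that request).

    Assumption: request b owns the half-open token range
    [input_row_offsets[b], input_row_offsets[b + 1]).
    """
    for b in range(len(input_row_offsets) - 1):
        start, end = input_row_offsets[b], input_row_offsets[b + 1]
        if start <= token_idx < end:
            return b, token_idx - start
    raise IndexError("token_idx is outside the ragged batch")


# Three requests contributing 3, 2, and 4 new tokens (tot_seq_len = 9).
offsets = [0, 3, 5, 9]
request, local_pos = token_to_request(offsets, 4)
```

The row of `freqs_cis` used for a given token presumably depends on its absolute position in the sequence, which may also include tokens already stored in the KV cache; that detail is not specified here.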