Mojo function
mla_fused_rope_rmsnorm_quantization
mla_fused_rope_rmsnorm_quantization[dtype: DType, freq_dtype: DType, gamma_dtype: DType, collection_t: KVCollectionT, out_rope_dtype: DType, //, kv_input_fn: def[width: Int](IndexList[2]) capturing -> SIMD[DType.bfloat16, width]](q_rope_output: TileTensor[out_rope_dtype, address_space=q_rope_output.address_space, linear_idx_type=q_rope_output.linear_idx_type, element_size=q_rope_output.element_size], q_rope: TileTensor[dtype, address_space=q_rope.address_space, linear_idx_type=q_rope.linear_idx_type, element_size=q_rope.element_size], input_row_offsets: TileTensor[DType.uint32, address_space=input_row_offsets.address_space, linear_idx_type=input_row_offsets.linear_idx_type, element_size=input_row_offsets.element_size], freqs_cis: TileTensor[freq_dtype, address_space=freqs_cis.address_space, linear_idx_type=freqs_cis.linear_idx_type, element_size=freqs_cis.element_size], gamma: TileTensor[gamma_dtype, address_space=gamma.address_space, linear_idx_type=gamma.linear_idx_type, element_size=gamma.element_size], kv_collection: collection_t, layer_idx: UInt32, epsilon: Float32, ctx: DeviceContext)
Launches the fused RoPE and RMSNorm kernel for MLA attention.
This function fuses three operations:
- RoPE applied to query and key cache rope parts.
- RMSNorm applied to the non-rope portion of the key cache.
- Quantization of the final results before writing to the KVCache object.
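The first two operations can be sketched on a single token row in plain Python. This is a reference illustration only, not the kernel's implementation. It assumes an interleaved (even, odd) pair layout for RoPE, a `freqs_cis` row stored as `[cos0, sin0, cos1, sin1, ...]`, and a KV row whose non-rope part comes first. None of these layout details are stated on this page, and the helper names are hypothetical. The quantization step is left out because the target format is not specified here.

```python
import math

def apply_rope(x, freqs_cis_row):
    """Rotate consecutive (even, odd) pairs of x by the matching (cos, sin) pair.

    Assumption: interleaved pairs, frequencies laid out as [cos0, sin0, cos1, sin1, ...].
    """
    out = []
    for i in range(0, len(x), 2):
        re, im = x[i], x[i + 1]
        cos, sin = freqs_cis_row[i], freqs_cis_row[i + 1]
        out.append(re * cos - im * sin)
        out.append(re * sin + im * cos)
    return out

def rms_norm(x, gamma, epsilon):
    """Scale x by the reciprocal of its root mean square, then by gamma."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + epsilon)
    return [v * inv_rms * g for v, g in zip(x, gamma)]

def process_kv_row(kv_row, gamma, freqs_cis_row, epsilon):
    """Normalize the leading non-rope part and rotate the trailing rope part.

    Assumption: the non-rope part has len(gamma) elements and precedes the rope part.
    """
    norm_dim = len(gamma)
    normed = rms_norm(kv_row[:norm_dim], gamma, epsilon)
    rotated = apply_rope(kv_row[norm_dim:], freqs_cis_row)
    return normed + rotated
```

In the fused kernel these steps run in a single launch, so the intermediate results do not need a round trip through global memory between RoPE, RMSNorm, and the cache write.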
Parameters:
- dtype (DType): Data type of query tensors.
- freq_dtype (DType): Data type of frequency cosine/sine values.
- gamma_dtype (DType): Data type of RMSNorm gamma weights.
- collection_t (KVCollectionT): Type of the KV cache collection.
- out_rope_dtype (DType): Data type of the RoPE output values.
- kv_input_fn (def[width: Int](IndexList[2]) capturing -> SIMD[DType.bfloat16, width]): Input lambda function to load the KV latent values. Shape: [tot_seq_len, cache_head_dim], where cache_head_dim = kv_lora_rank - qk_rope_head_dim.
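To make the role of kv_input_fn concrete, the following Python sketch models a loader that takes a 2-D index and returns `width` contiguous values from a [tot_seq_len, cache_head_dim] buffer. It is an analogy only: in Mojo, `width` is a compile-time parameter and the result is a SIMD vector, whereas here `width` is an ordinary argument and the result is a list. `make_kv_input_fn` and `kv_latent` are hypothetical names.

```python
def make_kv_input_fn(kv_latent, width):
    """Return a loader that reads `width` contiguous values starting at (row, col).

    kv_latent is a list of rows standing in for a [tot_seq_len, cache_head_dim] buffer.
    """
    def kv_input_fn(idx):
        row, col = idx
        return kv_latent[row][col:col + width]
    return kv_input_fn

# Two tokens, four latent values each (illustrative data only).
kv_latent = [[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]]
load2 = make_kv_input_fn(kv_latent, 2)
```

Passing the input as a capturing function rather than a tensor lets the caller decide where the latent values come from without the kernel needing to know the producer's layout.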
Args:
- q_rope_output (TileTensor[out_rope_dtype, address_space=q_rope_output.address_space, linear_idx_type=q_rope_output.linear_idx_type, element_size=q_rope_output.element_size]): Output tensor for RoPE-applied query projections. Shape: [tot_seq_len, num_heads, rope_dim].
- q_rope (TileTensor[dtype, address_space=q_rope.address_space, linear_idx_type=q_rope.linear_idx_type, element_size=q_rope.element_size]): Input query rope projections. Shape: [tot_seq_len, num_heads, rope_dim].
- input_row_offsets (TileTensor[DType.uint32, address_space=input_row_offsets.address_space, linear_idx_type=input_row_offsets.linear_idx_type, element_size=input_row_offsets.element_size]): Row offsets indicating request boundaries. Shape: [num_batches + 1].
- freqs_cis (TileTensor[freq_dtype, address_space=freqs_cis.address_space, linear_idx_type=freqs_cis.linear_idx_type, element_size=freqs_cis.element_size]): Precomputed RoPE frequency values. Shape: [max_seq_len, rope_dim].
- gamma (TileTensor[gamma_dtype, address_space=gamma.address_space, linear_idx_type=gamma.linear_idx_type, element_size=gamma.element_size]): RMSNorm gamma weights. Shape: [kv_norm_dim].
- kv_collection (collection_t): Paged KV cache collection.
- layer_idx (UInt32): Index of the current transformer layer.
- epsilon (Float32): Small constant for numerical stability in RMSNorm.
- ctx (DeviceContext): Device context for kernel execution.
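The [num_batches + 1] shape of input_row_offsets suggests the usual ragged-batch convention: request i owns the token rows from offsets[i] up to, but not including, offsets[i + 1]. The Python sketch below illustrates that reading with hypothetical helper names. It does not model tokens already held in the cache, which would presumably also affect the position used to index freqs_cis.

```python
from bisect import bisect_right

def request_lengths(input_row_offsets):
    """Number of new tokens in each request."""
    return [end - start
            for start, end in zip(input_row_offsets, input_row_offsets[1:])]

def locate_token(input_row_offsets, token_idx):
    """Map a global token row to (request index, position within that request)."""
    batch = bisect_right(input_row_offsets, token_idx) - 1
    return batch, token_idx - input_row_offsets[batch]

# Two requests with 3 and 2 tokens, so tot_seq_len is 5 (illustrative data only).
offsets = [0, 3, 5]
```

With these offsets, rows 0 to 2 belong to request 0 and rows 3 to 4 belong to request 1.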