
Mojo function

mla_fused_rope_rmsnorm_quantization

mla_fused_rope_rmsnorm_quantization[dtype: DType, freq_dtype: DType, gamma_dtype: DType, collection_t: KVCollectionT, out_rope_dtype: DType, //, kv_input_fn: def[width: Int](IndexList[2]) capturing -> SIMD[DType.bfloat16, width]](q_rope_output: TileTensor[out_rope_dtype, address_space=q_rope_output.address_space, linear_idx_type=q_rope_output.linear_idx_type, element_size=q_rope_output.element_size], q_rope: TileTensor[dtype, address_space=q_rope.address_space, linear_idx_type=q_rope.linear_idx_type, element_size=q_rope.element_size], input_row_offsets: TileTensor[DType.uint32, address_space=input_row_offsets.address_space, linear_idx_type=input_row_offsets.linear_idx_type, element_size=input_row_offsets.element_size], freqs_cis: TileTensor[freq_dtype, address_space=freqs_cis.address_space, linear_idx_type=freqs_cis.linear_idx_type, element_size=freqs_cis.element_size], gamma: TileTensor[gamma_dtype, address_space=gamma.address_space, linear_idx_type=gamma.linear_idx_type, element_size=gamma.element_size], kv_collection: collection_t, layer_idx: UInt32, epsilon: Float32, ctx: DeviceContext)

Launches the fused RoPE, RMSNorm, and quantization kernel for Multi-head Latent Attention (MLA).

This function fuses three operations:

  1. RoPE applied to the rope portions of the query and the key cache.
  2. RMSNorm applied to the non-rope portion of the key cache.
  3. Quantization of the final results before writing to the KVCache object.
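As a rough reference for what the three fused steps compute, the following pure-Python sketch processes a single KV latent row. It is not the Mojo kernel or its API, and all helper names are hypothetical. It assumes an interleaved-pair RoPE layout with cos/sin values passed in directly (standing in for `freqs_cis`) and the standard RMSNorm form `x / sqrt(mean(x^2) + epsilon) * gamma`. It also uses symmetric absmax quantization to int8 as a stand-in, because this page does not specify the cache's actual quantization scheme.

```python
import math


def rope_interleaved(x, cos, sin):
    """Rotate pairs (x[2i], x[2i+1]) by the angle whose cosine/sine are cos[i]/sin[i].

    Assumption: interleaved pair layout; the real kernel's layout may differ.
    """
    out = []
    for i in range(len(x) // 2):
        a, b = x[2 * i], x[2 * i + 1]
        out.append(a * cos[i] - b * sin[i])
        out.append(a * sin[i] + b * cos[i])
    return out


def rms_norm(x, gamma, epsilon):
    """Standard RMSNorm: y_j = x_j / sqrt(mean(x^2) + epsilon) * gamma_j."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + epsilon)
    return [v * inv_rms * g for v, g in zip(x, gamma)]


def quantize_absmax_int8(x):
    """Hypothetical stand-in quantizer: symmetric per-row absmax scaling to [-127, 127]."""
    amax = max(abs(v) for v in x)
    scale = amax / 127.0 if amax > 0 else 1.0
    q = [max(-127, min(127, round(v / scale))) for v in x]
    return q, scale


def fused_kv_row(latent_row, kv_lora_rank, cos, sin, gamma, epsilon):
    """One KV latent row: RMSNorm the non-rope part, RoPE the rope part, then quantize.

    Assumption: the non-rope values come first and the rope values last.
    """
    nope = latent_row[:kv_lora_rank]
    rope = latent_row[kv_lora_rank:]
    normed = rms_norm(nope, gamma, epsilon)
    rotated = rope_interleaved(rope, cos, sin)
    return quantize_absmax_int8(normed + rotated)
```

Going by the signature, the rotated query values are written to `q_rope_output` (of type `out_rope_dtype`) rather than to the KV cache, so the sketch quantizes only the key-side row.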

Parameters:

  • dtype (DType): Data type of query tensors.
  • freq_dtype (DType): Data type of frequency cosine/sine values.
  • gamma_dtype (DType): Data type of RMSNorm gamma weights.
  • collection_t (KVCollectionT): Type of the KV cache collection.
  • out_rope_dtype (DType): Data type of the RoPE output values.
  • kv_input_fn (def[width: Int](IndexList[2]) capturing -> SIMD[DType.bfloat16, width]): Input lambda function that loads the KV latent values. Shape: [tot_seq_len, cache_head_dim], where cache_head_dim = kv_lora_rank + qk_rope_head_dim.
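The `kv_input_fn` contract can be illustrated with a hypothetical Python analogue. It is a closure that takes a vector width and a `[row, col]` index and returns `width` contiguous latent values from a `[tot_seq_len, cache_head_dim]` buffer. In Mojo, `width` is a compile-time parameter. Here it is an ordinary argument. The loop reads one row in `width`-sized chunks and splits it at `kv_lora_rank` into the part that goes through RMSNorm and the part that goes through RoPE. Placing the non-rope values first is an assumption consistent with `cache_head_dim = kv_lora_rank + qk_rope_head_dim`. This page does not state that order.

```python
def make_kv_input_fn(latent):
    """Hypothetical analogue of kv_input_fn: return `width` values starting at (row, col)."""
    def kv_input_fn(width, idx):
        row, col = idx
        return latent[row][col:col + width]
    return kv_input_fn


def load_row_split(kv_input_fn, row, kv_lora_rank, qk_rope_head_dim, width):
    """Load one latent row in vector-width chunks, then split it into (non-rope, rope) parts.

    Assumption: non-rope values occupy the first kv_lora_rank columns.
    """
    cache_head_dim = kv_lora_rank + qk_rope_head_dim
    values = []
    for col in range(0, cache_head_dim, width):
        # The last chunk may be narrower than `width`.
        values.extend(kv_input_fn(min(width, cache_head_dim - col), (row, col)))
    return values[:kv_lora_rank], values[kv_lora_rank:]


# Tiny example: 2 rows, kv_lora_rank = 4, qk_rope_head_dim = 2, so cache_head_dim = 6.
latent = [[10 * r + c for c in range(6)] for r in range(2)]
nope, rope = load_row_split(make_kv_input_fn(latent), 1, 4, 2, 4)
```

With kv_lora_rank = 512 and qk_rope_head_dim = 64 (the DeepSeek-V2/V3 configuration), each row would hold 576 values.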

Args: