IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo function

mla_fused_rope_rmsnorm_quantization

def mla_fused_rope_rmsnorm_quantization[dtype: DType, freq_dtype: DType, gamma_dtype: DType, collection_t: KVCollectionT, out_rope_dtype: DType, //, kv_input_fn: def[width: Int](IndexList[Int(2)]) capturing -> SIMD[DType.bfloat16, width]](q_rope_output: TileTensor[out_rope_dtype, Storage=q_rope_output.Storage, address_space=q_rope_output.address_space, linear_idx_type=q_rope_output.linear_idx_type, element_size=q_rope_output.element_size], q_rope: TileTensor[dtype, Storage=q_rope.Storage, address_space=q_rope.address_space, linear_idx_type=q_rope.linear_idx_type, element_size=q_rope.element_size], input_row_offsets: TileTensor[DType.uint32, Storage=input_row_offsets.Storage, address_space=input_row_offsets.address_space, linear_idx_type=input_row_offsets.linear_idx_type, element_size=input_row_offsets.element_size], freqs_cis: TileTensor[freq_dtype, Storage=freqs_cis.Storage, address_space=freqs_cis.address_space, linear_idx_type=freqs_cis.linear_idx_type, element_size=freqs_cis.element_size], gamma: TileTensor[gamma_dtype, Storage=gamma.Storage, address_space=gamma.address_space, linear_idx_type=gamma.linear_idx_type, element_size=gamma.element_size], kv_collection: collection_t, layer_idx: UInt32, epsilon: Float32, ctx: DeviceContext)
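The signature takes a ragged batch through the `uint32` tensor `input_row_offsets`. This page does not describe its layout, so the following Python sketch assumes the common ragged convention: `batch_size + 1` exclusive prefix sums, where sequence `b` owns rows `[offsets[b], offsets[b+1])`. It shows how a flat row index maps to a sequence index and a position within that sequence, which is the kind of lookup a per-token RoPE needs. `row_to_sequence` is a hypothetical helper and not part of the MAX API. The real kernel may also offset the position by the sequence's already-cached length, which this sketch does not model.

```python
from bisect import bisect_right


def row_to_sequence(input_row_offsets, row):
    """Map a flat row index to (sequence index, position within the new tokens).

    Assumes input_row_offsets has batch_size + 1 non-decreasing entries and
    sequence b owns rows [input_row_offsets[b], input_row_offsets[b + 1]).
    bisect_right skips over empty sequences (repeated offsets) correctly.
    """
    b = bisect_right(input_row_offsets, row) - 1
    return b, row - input_row_offsets[b]


# Three sequences of lengths 3, 2, and 4 packed into 9 rows.
offsets = [0, 3, 5, 9]
mapping = [row_to_sequence(offsets, r) for r in range(offsets[-1])]
```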

Launches the fused RoPE, RMSNorm, and quantization kernel for MLA (Multi-head Latent Attention).

This function fuses three operations:

  1. RoPE applied to the rope portions of the query and of the key cache.
  2. RMSNorm applied to the non-rope portion of the key cache.
  3. Quantization of the final results before writing to the KVCache object.
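The three fused steps can be written out per row as a plain Python numerical reference. This is a sketch under stated assumptions, not the kernel's implementation. It assumes the latent row stores the non-rope part first (`kv_lora_rank` values) followed by the rope part, that RoPE rotates adjacent pairs `(x[2i], x[2i+1])` using cosine/sine values already extracted from `freqs_cis` (the actual packing of `freqs_cis` is not described here), and it uses symmetric per-row int8 quantization as a hypothetical stand-in because the cache's quantized format is not specified on this page.

```python
import math


def rms_norm(x, gamma, epsilon):
    # RMSNorm: x * gamma / sqrt(mean(x^2) + epsilon)
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + epsilon)
    return [v * inv_rms * g for v, g in zip(x, gamma)]


def rope_interleaved(x, cos, sin):
    # Assumed layout: rotate each adjacent pair (x[2i], x[2i+1]) by the
    # angle whose cosine and sine are cos[i] and sin[i].
    out = []
    for i in range(len(x) // 2):
        a, b = x[2 * i], x[2 * i + 1]
        out.append(a * cos[i] - b * sin[i])
        out.append(a * sin[i] + b * cos[i])
    return out


def quantize_int8(x):
    # Hypothetical stand-in for the cache's quantization: one symmetric
    # scale per row, values clamped to [-127, 127].
    amax = max(abs(v) for v in x)
    scale = amax / 127.0 if amax > 0 else 1.0
    q = [max(-127, min(127, round(v / scale))) for v in x]
    return q, scale


def fused_kv_row(row, kv_lora_rank, gamma, cos, sin, epsilon):
    # Step 2: RMSNorm on the non-rope part; step 1: RoPE on the rope part;
    # step 3: quantize the concatenated row before it goes to the KV cache.
    nope, rope = row[:kv_lora_rank], row[kv_lora_rank:]
    normed = rms_norm(nope, gamma, epsilon)
    rotated = rope_interleaved(rope, cos, sin)
    return quantize_int8(normed + rotated)


def fused_q_head(q_rope_head, cos, sin):
    # Query side: RoPE only, result goes to q_rope_output (the cast to
    # out_rope_dtype is not modeled here).
    return rope_interleaved(q_rope_head, cos, sin)
```

Fusing these steps lets one kernel read each latent row once and write the final quantized values, instead of materializing the normalized and rotated intermediates between separate launches.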

Parameters:

  • dtype (DType): Data type of query tensors.
  • freq_dtype (DType): Data type of frequency cosine/sine values.
  • gamma_dtype (DType): Data type of RMSNorm gamma weights.
  • collection_t (KVCollectionT): Type of the KV cache collection.
  • out_rope_dtype (DType): Data type of the RoPE output values.
  • kv_input_fn (def[width: Int](IndexList[Int(2)]) capturing -> SIMD[DType.bfloat16, width]): Input lambda that loads the KV latent values, which have shape [tot_seq_len, cache_head_dim], where cache_head_dim = kv_lora_rank + qk_rope_head_dim.
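The `kv_input_fn` contract (a `(row, col)` index plus a SIMD width, returning `width` contiguous values) can be mimicked in Python to make the shape arithmetic concrete. This is a hedged sketch: `make_kv_input_fn` is a hypothetical helper, the buffer is assumed to be row-major, the non-rope columns are assumed to come before the rope columns, and the sizes 512 and 64 are only illustrative (they match the `kv_lora_rank` and `qk_rope_head_dim` used by DeepSeek-V3).

```python
def make_kv_input_fn(buffer, cache_head_dim):
    """Return a loader over a flat, row-major [tot_seq_len, cache_head_dim] buffer."""

    def kv_input_fn(width, idx):
        # idx mirrors IndexList[2] as a (row, col) pair; width stands in for
        # the compile-time SIMD width parameter of the Mojo signature.
        row, col = idx
        start = row * cache_head_dim + col
        return buffer[start:start + width]

    return kv_input_fn


kv_lora_rank = 512
qk_rope_head_dim = 64
cache_head_dim = kv_lora_rank + qk_rope_head_dim  # 576

tot_seq_len = 3
buffer = [float(i) for i in range(tot_seq_len * cache_head_dim)]
load = make_kv_input_fn(buffer, cache_head_dim)

# Assumed split: columns [0, kv_lora_rank) are the non-rope part (RMSNorm),
# columns [kv_lora_rank, cache_head_dim) are the rope part (RoPE).
nope_chunk = load(8, (1, 0))
rope_chunk = load(8, (1, kv_lora_rank))
```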

Args: