IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo function

mla_decode_branch_fp8

def mla_decode_branch_fp8[dtype: DType, fp8_dtype: DType, fp8_scale_dtype: DType, collection_t: KVCollectionT, //, m_scale_granularity: Int, n_scale_granularity: Int, k_scale_granularity: Int, mask_str: StringSlice[StaticConstantOrigin], kv_input_fn: def[width: Int](IndexList[2]) capturing -> SIMD[DType.bfloat16, width], target: StringSlice[StaticConstantOrigin] = StringSlice("cpu"), sparse_mla: Bool = False](output: TileTensor[dtype, linear_idx_type=output.linear_idx_type, element_size=output.element_size], q: TileTensor[dtype, linear_idx_type=q.linear_idx_type, element_size=q.element_size], input_row_offsets: TileTensor[DType.uint32, linear_idx_type=input_row_offsets.linear_idx_type, element_size=input_row_offsets.element_size], freqs_cis: TileTensor[linear_idx_type=freqs_cis.linear_idx_type, element_size=freqs_cis.element_size], kv_norm_gamma: TileTensor[linear_idx_type=kv_norm_gamma.linear_idx_type, element_size=kv_norm_gamma.element_size], kv_collection: collection_t, layer_idx: UInt32, scale: Float32, epsilon: Float32, w_uk: TileTensor[fp8_dtype, linear_idx_type=w_uk.linear_idx_type, element_size=w_uk.element_size], w_uk_scale: TileTensor[fp8_scale_dtype, linear_idx_type=w_uk_scale.linear_idx_type, element_size=w_uk_scale.element_size], w_uv: TileTensor[fp8_dtype, linear_idx_type=w_uv.linear_idx_type, element_size=w_uv.element_size], w_uv_scale: TileTensor[fp8_scale_dtype, linear_idx_type=w_uv_scale.linear_idx_type, element_size=w_uv_scale.element_size], scalar_args_buf: TileTensor[DType.int64, linear_idx_type=scalar_args_buf.linear_idx_type, element_size=scalar_args_buf.element_size], ctx: DeviceContext, d_indices: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None, indices_stride: Int = 0, topk_lengths: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None, attn_sink_ptr: OptionalReg[UnsafePointer[Float32, MutAnyOrigin]] = None, extra_k: OptionalReg[collection_t.CacheType] = None, extra_d_indices: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None, extra_indices_stride: Int = 0, extra_topk_lengths: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None, extra_scales_ptr: OptionalReg[UnsafePointer[Float32, MutAnyOrigin]] = None, num_partitions_in: Optional[Int] = None)

This is a manually fused kernel that performs the following operations:

  • Apply RoPE to the query and the key cache (in place).
  • Apply RMSNorm to the non-rope portion of the key cache (in place).
  • Project q_nope to kv_latent_dim through an fp8 batched matmul: q_nope_proj = q_nope_t @ w_uk.
  • Concatenate q_nope_proj and q_rope: q_full = concat(q_nope_proj, q_rope, axis=2).
  • Perform MLA decode.
  • Project raw_output to v_head_dim through another fp8 batched matmul: output = raw_output_t @ w_uv.
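
The two cache-side steps (RoPE and RMSNorm applied to one cached row) can be sketched in plain Python as a reference for the arithmetic involved. This is not Mojo and not the kernel's implementation: the helper names are hypothetical, the row is assumed to be laid out as [latent | rope] (matching cache_head_dim = kv_lora_rank + qk_rope_head_dim), and the RoPE pairing convention and angle schedule are assumptions, since the kernel reads its frequencies from freqs_cis.

```python
import math


def rms_norm(x, gamma, epsilon):
    """RMSNorm: x / sqrt(mean(x^2) + epsilon), scaled elementwise by gamma."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + epsilon)
    return [v * inv_rms * g for v, g in zip(x, gamma)]


def apply_rope_interleaved(x, pos, base=10000.0):
    """Rotate pairs (x[2i], x[2i+1]) by the angle pos * base**(-2i/d).

    ASSUMPTION: interleaved pairing and this frequency schedule; the kernel
    takes its rotation frequencies from freqs_cis instead.
    """
    d = len(x)
    out = list(x)
    for i in range(d // 2):
        theta = pos * base ** (-2.0 * i / d)
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[2 * i], x[2 * i + 1]
        out[2 * i] = a * c - b * s
        out[2 * i + 1] = a * s + b * c
    return out


def prepare_cache_row(row, kv_lora_rank, gamma, epsilon, pos):
    """Hypothetical helper: RMSNorm the latent part, RoPE the rope part.

    Returns a new list, whereas the kernel updates the cache in place.
    """
    latent, rope = row[:kv_lora_rank], row[kv_lora_rank:]
    return rms_norm(latent, gamma, epsilon) + apply_rope_interleaved(rope, pos)
```

For example, prepare_cache_row([2.0, 2.0, 2.0, 2.0, 3.0, 4.0], 4, [1.0] * 4, 0.0, 0) normalizes the latent part to all ones and leaves the rope part unchanged, because the rotation angle at position 0 is zero.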
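
The query-side steps (projection through w_uk, concatenation with q_rope, decode attention, and projection through w_uv) can likewise be sketched for one head and one decode token. Assumptions: plain Python floats stand in for the fp8 operands and their scales, no mask is applied, the cache rows are already normalized and rotated, w_uk is a per-head matrix mapping the length of q_nope to kv_lora_rank, and w_uv maps kv_lora_rank to v_head_dim. The helper names are hypothetical.

```python
import math


def matvec_t(x, w):
    """Return x^T @ w for a vector x of length K and a K x N matrix w (list of rows)."""
    n_cols = len(w[0])
    return [sum(x[k] * w[k][j] for k in range(len(x))) for j in range(n_cols)]


def mla_decode_one_head(q_nope, q_rope, cache, w_uk, w_uv, kv_lora_rank, scale):
    """Reference MLA decode for one head and one query token (hypothetical helper).

    cache: list of rows laid out as [latent (kv_lora_rank) | rope], assumed to be
    already RMS-normalized and rotated.
    """
    # Project q_nope into the latent space: q_nope_proj = q_nope^T @ w_uk.
    q_nope_proj = matvec_t(q_nope, w_uk)
    # q_full = concat(q_nope_proj, q_rope), matching the cache row layout.
    q_full = q_nope_proj + q_rope
    # Scaled dot-product scores against every cached row (no mask applied).
    scores = [scale * sum(a * b for a, b in zip(q_full, row)) for row in cache]
    # Numerically stable softmax over cached positions.
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    probs = [e / total for e in exps]
    # raw_output: probability-weighted sum of the latent part of each row.
    raw_output = [
        sum(p * row[i] for p, row in zip(probs, cache)) for i in range(kv_lora_rank)
    ]
    # Project back to v_head_dim: output = raw_output^T @ w_uv.
    return matvec_t(raw_output, w_uv)
```

With a single cached row the softmax weight is 1, so mla_decode_one_head([1.0, 0.0], [0.5], [[3.0, 4.0, 1.0]], identity, identity, 2, 1.0), using 2 x 2 identity matrices for w_uk and w_uv, returns that row's latent part, [3.0, 4.0]. Because the latent part of each row serves as both key (together with the rope part) and value, each cached row only needs to be read once per query.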

Parameters:

  • dtype (DType): Data type of the input and output tensors.
  • fp8_dtype (DType): Data type of the fp8 projection weights w_uk and w_uv.
  • fp8_scale_dtype (DType): Data type of the fp8 scale tensors w_uk_scale and w_uv_scale.
  • collection_t (KVCollectionT): Type of the KV collection.
  • m_scale_granularity (Int): Scale granularity along the M dimension of the fp8 matrix multiplications.
  • n_scale_granularity (Int): Scale granularity along the N dimension of the fp8 matrix multiplications.
  • k_scale_granularity (Int): Scale granularity along the K dimension of the fp8 matrix multiplications.
  • mask_str (StringSlice[StaticConstantOrigin]): Mask variant.
  • kv_input_fn (def[width: Int](IndexList[2]) capturing -> SIMD[DType.bfloat16, width]): Input lambda function to load the KV latent values. Shape: [tot_seq_len, cache_head_dim], where cache_head_dim = kv_lora_rank + qk_rope_head_dim.
  • target (StringSlice[StaticConstantOrigin]): Target device.
  • sparse_mla (Bool): Whether to use sparse MLA.
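
The three granularity parameters describe block-wise scaling, where one scale value is shared by a tile of consecutive elements. The plain-Python sketch below assumes A is [M, K] with one scale per m_gran x k_gran tile, and B is stored as [N, K] with one scale per n_gran x k_gran tile. Which operand each granularity applies to inside this kernel, and the actual layouts of w_uk_scale and w_uv_scale, are assumptions; Python floats stand in for fp8 values, and the helper names are hypothetical.

```python
def dequant_blockwise(q, scales, row_gran, col_gran):
    """Multiply each element by the scale of the row_gran x col_gran tile it falls in."""
    return [
        [q[r][c] * scales[r // row_gran][c // col_gran] for c in range(len(q[0]))]
        for r in range(len(q))
    ]


def blockscaled_matmul(a_q, a_scales, b_q, b_scales, m_gran, n_gran, k_gran):
    """C[M x N] = dequant(A)[M x K] @ dequant(B)^T, with B stored as [N x K].

    ASSUMPTION: A tiles are m_gran x k_gran and B tiles are n_gran x k_gran.
    """
    a = dequant_blockwise(a_q, a_scales, m_gran, k_gran)
    b = dequant_blockwise(b_q, b_scales, n_gran, k_gran)
    m, n, k = len(a), len(b), len(a[0])
    return [[sum(a[i][t] * b[j][t] for t in range(k)) for j in range(n)] for i in range(m)]
```

Coarser granularities mean fewer scale values to load but a less precise fit of each scale to its tile. For instance, with m_gran = 1, n_gran = 2 and k_gran = 2, a 2 x 4 operand A carries a 2 x 2 grid of scales, while a 2 x 4 operand B carries a single row of two scales.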
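
The kv_input_fn contract can be illustrated with a hypothetical Python loader over a flat, row-major [tot_seq_len, cache_head_dim] buffer: given a (token, column) index, it returns width contiguous values, loosely mirroring the IndexList[2] argument and the SIMD[DType.bfloat16, width] return type. The row-major layout, the index order, and the dimension values below are illustrative assumptions.

```python
KV_LORA_RANK = 512  # illustrative value
QK_ROPE_HEAD_DIM = 64  # illustrative value
CACHE_HEAD_DIM = KV_LORA_RANK + QK_ROPE_HEAD_DIM  # latent part, then rope part


def make_kv_input_fn(buf, cache_head_dim):
    """Hypothetical loader over a flat row-major [tot_seq_len, cache_head_dim] buffer."""

    def kv_input_fn(width, idx):
        tok, col = idx
        # A vector load must stay within one row of the cache.
        assert col + width <= cache_head_dim
        start = tok * cache_head_dim + col
        return buf[start:start + width]

    return kv_input_fn
```

Columns [0, KV_LORA_RANK) of a row hold the latent values that receive RMSNorm, and columns [KV_LORA_RANK, CACHE_HEAD_DIM) hold the rope values.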

Args: