Mojo function
mla_decode_branch_fp8
mla_decode_branch_fp8[dtype: DType, fp8_dtype: DType, fp8_scale_dtype: DType, collection_t: KVCollectionT, //, m_scale_granularity: Int, n_scale_granularity: Int, k_scale_granularity: Int, mask_str: StringSlice[StaticConstantOrigin], kv_input_fn: def[width: Int](IndexList[2]) capturing -> SIMD[DType.bfloat16, width], target: StringSlice[StaticConstantOrigin] = StringSlice("cpu"), sparse_mla: Bool = False](output: TileTensor[dtype, linear_idx_type=output.linear_idx_type, element_size=output.element_size], q: TileTensor[dtype, linear_idx_type=q.linear_idx_type, element_size=q.element_size], input_row_offsets: TileTensor[DType.uint32, linear_idx_type=input_row_offsets.linear_idx_type, element_size=input_row_offsets.element_size], freqs_cis: TileTensor[linear_idx_type=freqs_cis.linear_idx_type, element_size=freqs_cis.element_size], kv_norm_gamma: TileTensor[linear_idx_type=kv_norm_gamma.linear_idx_type, element_size=kv_norm_gamma.element_size], kv_collection: collection_t, layer_idx: UInt32, scale: Float32, epsilon: Float32, w_uk: TileTensor[fp8_dtype, linear_idx_type=w_uk.linear_idx_type, element_size=w_uk.element_size], w_uk_scale: TileTensor[fp8_scale_dtype, linear_idx_type=w_uk_scale.linear_idx_type, element_size=w_uk_scale.element_size], w_uv: TileTensor[fp8_dtype, linear_idx_type=w_uv.linear_idx_type, element_size=w_uv.element_size], w_uv_scale: TileTensor[fp8_scale_dtype, linear_idx_type=w_uv_scale.linear_idx_type, element_size=w_uv_scale.element_size], scalar_args_buf: TileTensor[DType.int64, linear_idx_type=scalar_args_buf.linear_idx_type, element_size=scalar_args_buf.element_size], ctx: DeviceContext, d_indices: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None, indices_stride: Int = 0, topk_lengths: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None, attn_sink_ptr: OptionalReg[UnsafePointer[Float32, MutAnyOrigin]] = None, extra_k: OptionalReg[collection_t.CacheType] = None, extra_d_indices: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None, 
extra_indices_stride: Int = 0, extra_topk_lengths: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None, extra_scales_ptr: OptionalReg[UnsafePointer[Float32, MutAnyOrigin]] = None)
This is a manually fused kernel that performs the following operations:
- Apply RoPE to the query and the key cache (in-place).
- Apply RMSNorm to the non-rope portion of the key cache (in-place).
- Project q_nope to kv_latent_dim through an fp8 batched matmul: q_nope_proj = q_nope_t @ w_uk.
- Concatenate q_nope_proj and q_rope: q_full = concat(q_nope_proj, q_rope, axis=2).
- Perform MLA decode.
- Project raw_output to v_head_dim through another fp8 batched matmul: output = raw_output_t @ w_uv.
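The projection, concatenation, decode, and back-projection steps above can be sketched in NumPy. This is a minimal illustration of the math only: it ignores RoPE, RMSNorm, fp8 quantization, paging, and masking, and all shape values are hypothetical small numbers.

```python
import numpy as np

# Hypothetical small shapes for illustration
tot_seq_len, num_heads = 4, 8
qk_nope, qk_rope, kv_latent, v_head = 16, 8, 32, 16
cache_len = 12

rng = np.random.default_rng(0)
q_nope = rng.standard_normal((tot_seq_len, num_heads, qk_nope))
q_rope = rng.standard_normal((tot_seq_len, num_heads, qk_rope))
w_uk = rng.standard_normal((num_heads, kv_latent, qk_nope))
w_uv = rng.standard_normal((num_heads, v_head, kv_latent))
# KV cache in latent space: latent part plus rope part per position
kv_cache = rng.standard_normal((cache_len, kv_latent + qk_rope))

# 1) Project q_nope into latent space (per-head batched matmul):
#    q_nope_proj[s, h, :] = w_uk[h] @ q_nope[s, h, :]
q_nope_proj = np.einsum("shd,hld->shl", q_nope, w_uk)

# 2) Concatenate with the rope part along the head dimension
q_full = np.concatenate([q_nope_proj, q_rope], axis=2)

# 3) MLA decode: attention over the latent KV cache; the latent part of
#    each cache row doubles as the value
scale = 1.0 / np.sqrt(q_full.shape[-1])
scores = np.einsum("shd,td->sht", q_full, kv_cache) * scale
probs = np.exp(scores - scores.max(-1, keepdims=True))
probs /= probs.sum(-1, keepdims=True)
raw_output = np.einsum("sht,tl->shl", probs, kv_cache[:, :kv_latent])

# 4) Project back to v_head_dim per head: output[s, h] = w_uv[h] @ raw_output[s, h]
output = np.einsum("shl,hvl->shv", raw_output, w_uv)
assert output.shape == (tot_seq_len, num_heads, v_head)
```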
Parameters:
- dtype (DType): Data type of the input and output tensors.
- fp8_dtype (DType): Data type of the fp8 input and output tensors.
- fp8_scale_dtype (DType): Data type of the fp8 scale input and output tensors.
- collection_t (KVCollectionT): Type of the KV collection.
- m_scale_granularity (Int): Granularity of the scale for the M dimension of the matrix multiplication.
- n_scale_granularity (Int): Granularity of the scale for the N dimension of the matrix multiplication.
- k_scale_granularity (Int): Granularity of the scale for the K dimension of the matrix multiplication.
- mask_str (StringSlice[StaticConstantOrigin]): Mask variant.
- kv_input_fn (def[width: Int](IndexList[2]) capturing -> SIMD[DType.bfloat16, width]): Input lambda function to load the KV latent values. Shape: [tot_seq_len, cache_head_dim], where cache_head_dim = kv_lora_rank - qk_rope_head_dim.
- target (StringSlice[StaticConstantOrigin]): Target device.
- sparse_mla (Bool): Whether to use sparse MLA.
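The scale-granularity parameters describe how many weight elements share one fp8 scale along each matmul dimension. The sketch below illustrates blockwise scaling for a 2-D weight under an assumed (n_granularity, k_granularity) block shape; it quantizes to scaled float32 rather than a real fp8 type, since NumPy has no fp8 dtype, and the 448.0 max corresponds to the e4m3 format.

```python
import numpy as np

# Hypothetical granularities: one scale per (1, 128) block of the weight
n_gran, k_gran = 1, 128
N, K = 4, 256
E4M3_MAX = 448.0

rng = np.random.default_rng(0)
w = rng.standard_normal((N, K)).astype(np.float32)

# Quantize: one scale per block, mapping the block's max magnitude
# to the fp8 e4m3 maximum (the actual fp8 cast is omitted here)
scales = np.empty((N // n_gran, K // k_gran), np.float32)
w_q = np.empty_like(w)
for i in range(scales.shape[0]):
    for j in range(scales.shape[1]):
        blk = w[i * n_gran:(i + 1) * n_gran, j * k_gran:(j + 1) * k_gran]
        s = np.abs(blk).max() / E4M3_MAX
        scales[i, j] = s
        w_q[i * n_gran:(i + 1) * n_gran, j * k_gran:(j + 1) * k_gran] = blk / s

# Dequantize by broadcasting each block's scale back over its block
w_dq = w_q * np.kron(scales, np.ones((n_gran, k_gran), np.float32))
assert np.allclose(w_dq, w, atol=1e-4)
```

This is why w_uk_scale and w_uv_scale have shapes that "vary depending on the float8_config": the scale tensor has one entry per granularity block of the corresponding weight.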
Args:
- output (TileTensor[dtype, linear_idx_type=output.linear_idx_type, element_size=output.element_size]): Output tensor of shape [tot_seq_len, num_heads, v_head_dim].
- q (TileTensor[dtype, linear_idx_type=q.linear_idx_type, element_size=q.element_size]): Combined query tensor containing both nope and rope parts. Shape: [tot_seq_len, num_heads, qk_nope_head_dim + qk_rope_head_dim].
- input_row_offsets (TileTensor[DType.uint32, linear_idx_type=input_row_offsets.linear_idx_type, element_size=input_row_offsets.element_size]): Indicates where each request starts and ends in q. Shape: [num_batches + 1].
- freqs_cis (TileTensor[linear_idx_type=freqs_cis.linear_idx_type, element_size=freqs_cis.element_size]): Precomputed RoPE frequency values for rotary position embeddings. Shape: [max_seq_len, qk_rope_head_dim].
- kv_norm_gamma (TileTensor[linear_idx_type=kv_norm_gamma.linear_idx_type, element_size=kv_norm_gamma.element_size]): RMSNorm gamma weights for normalizing the KV cache. Shape: [kv_lora_rank].
- kv_collection (collection_t): Paged KV cache object.
- layer_idx (UInt32): Layer index.
- scale (Float32): Scale for the attention calculation.
- epsilon (Float32): Small constant for numerical stability in RMSNorm.
- w_uk (TileTensor[fp8_dtype, linear_idx_type=w_uk.linear_idx_type, element_size=w_uk.element_size]): Weight matrix for projecting the non-rope part of each query head to KV latent space. Shape: [num_heads, kv_latent_dim, qk_nope_head_dim].
- w_uk_scale (TileTensor[fp8_scale_dtype, linear_idx_type=w_uk_scale.linear_idx_type, element_size=w_uk_scale.element_size]): The scale for the w_uk weight matrix. Shape varies depending on the float8_config.
- w_uv (TileTensor[fp8_dtype, linear_idx_type=w_uv.linear_idx_type, element_size=w_uv.element_size]): Weight matrix for projecting the attention output back to each head's original space. Shape: [num_heads, v_head_dim, kv_latent_dim].
- w_uv_scale (TileTensor[fp8_scale_dtype, linear_idx_type=w_uv_scale.linear_idx_type, element_size=w_uv_scale.element_size]): The scale for the w_uv weight matrix. Shape varies depending on the float8_config.
- scalar_args_buf (TileTensor[DType.int64, linear_idx_type=scalar_args_buf.linear_idx_type, element_size=scalar_args_buf.element_size]): Packed MLA dispatch metadata buffer.
- ctx (DeviceContext): Device context.
- d_indices (OptionalReg[UnsafePointer[Int32, MutAnyOrigin]]): Sparse decode packed indices (null when dense).
- indices_stride (Int): Row stride in d_indices.
- topk_lengths (OptionalReg[UnsafePointer[Int32, MutAnyOrigin]]): Per-batch valid top-k counts.
- attn_sink_ptr (OptionalReg[UnsafePointer[Float32, MutAnyOrigin]]): Optional per-batch attention sink weights.
- extra_k (OptionalReg[collection_t.CacheType]): Optional second key cache operand (see flare_mla_decoding).
- extra_d_indices (OptionalReg[UnsafePointer[Int32, MutAnyOrigin]]): Extra-stream sparse indices.
- extra_indices_stride (Int): Stride for extra_d_indices.
- extra_topk_lengths (OptionalReg[UnsafePointer[Int32, MutAnyOrigin]]): Extra-stream per-batch lengths.
- extra_scales_ptr (OptionalReg[UnsafePointer[Float32, MutAnyOrigin]]): Extra-stream scales.
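Because q packs all requests into one ragged tensor, input_row_offsets is what recovers per-request boundaries. A small sketch of the indexing convention, with hypothetical request lengths:

```python
import numpy as np

# Hypothetical ragged batch: 3 requests of lengths 2, 1, and 4.
# Entry b marks where request b starts; the last entry is tot_seq_len.
input_row_offsets = np.array([0, 2, 3, 7], dtype=np.uint32)

num_batches = len(input_row_offsets) - 1          # 3
tot_seq_len = int(input_row_offsets[-1])          # 7 packed rows in q
lengths = np.diff(input_row_offsets)              # per-request lengths

# Request b owns rows [input_row_offsets[b], input_row_offsets[b + 1])
spans = [
    (int(input_row_offsets[b]), int(input_row_offsets[b + 1]))
    for b in range(num_batches)
]
# e.g. spans[1] == (2, 3): request 1's queries are q[2:3]
```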