
Mojo function

mla_decode_branch_fp8

mla_decode_branch_fp8[
    dtype: DType,
    fp8_dtype: DType,
    fp8_scale_dtype: DType,
    collection_t: KVCollectionT, //,
    m_scale_granularity: Int,
    n_scale_granularity: Int,
    k_scale_granularity: Int,
    mask_str: StringSlice[StaticConstantOrigin],
    score_mod_str: StringSlice[StaticConstantOrigin],
    target: StringSlice[StaticConstantOrigin] = "cpu"
](
    output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment],
    q_nope: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment],
    q_rope: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment],
    input_row_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment],
    kv_collection: collection_t,
    layer_idx: UInt32,
    scale: Float32,
    w_uk: LayoutTensor[fp8_dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment],
    w_uk_scale: LayoutTensor[fp8_scale_dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment],
    w_uv: LayoutTensor[fp8_dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment],
    w_uv_scale: LayoutTensor[fp8_scale_dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment],
    ctx: DeviceContext
)

This is a manually fused kernel that performs the following operations:

  • Project q_nope to kv_latent_dim through an fp8 batched matmul: q_nope_proj = q_nope_t @ w_uk.
  • Concatenate q_nope_proj and q_rope: q_full = concat(q_nope_proj, q_rope, axis=2).
  • Perform MLA decode, producing raw_output in the KV latent space.
  • Project raw_output to v_head_dim through another fp8 batched matmul: output = raw_output_t @ w_uv.
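
For orientation, here is a minimal NumPy sketch of the logical computation the kernel fuses, written in full precision for a single request against an unmasked, unpaged latent cache. All names and shapes beyond those in the signature are illustrative assumptions, not the kernel's API:

```python
import numpy as np

# Illustrative shapes (hypothetical values, not the kernel's constraints).
seq, num_heads = 1, 8
qk_nope, qk_rope, kv_latent, v_head = 128, 64, 512, 128
cache_len = 32

rng = np.random.default_rng(0)
q_nope = rng.standard_normal((seq, num_heads, qk_nope), dtype=np.float32)
q_rope = rng.standard_normal((seq, num_heads, qk_rope), dtype=np.float32)
w_uk = rng.standard_normal((num_heads, kv_latent, qk_nope), dtype=np.float32)
w_uv = rng.standard_normal((num_heads, v_head, kv_latent), dtype=np.float32)
# Assumed MLA cache layout: a head-shared latent part and a RoPE key part.
kv_latent_cache = rng.standard_normal((cache_len, kv_latent), dtype=np.float32)
k_rope_cache = rng.standard_normal((cache_len, qk_rope), dtype=np.float32)
scale = 1.0 / np.sqrt(qk_nope + qk_rope)

# 1. Project q_nope into the KV latent space: q_nope_proj = q_nope_t @ w_uk.
q_nope_proj = np.einsum("shd,hld->shl", q_nope, w_uk)  # [seq, heads, kv_latent]
# 2. Concatenate with the RoPE part: q_full = concat(q_nope_proj, q_rope, axis=2).
q_full = np.concatenate([q_nope_proj, q_rope], axis=2)
# 3. MLA decode: attend with keys = latent ++ rope, values = latent.
k_full = np.concatenate([kv_latent_cache, k_rope_cache], axis=1)
scores = np.einsum("shd,td->sht", q_full, k_full) * scale
probs = np.exp(scores - scores.max(-1, keepdims=True))
probs /= probs.sum(-1, keepdims=True)
raw_output = np.einsum("sht,tl->shl", probs, kv_latent_cache)  # latent space
# 4. Project back to each head's value space: output = raw_output_t @ w_uv.
output = np.einsum("shl,hvl->shv", raw_output, w_uv)  # [seq, heads, v_head]
```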

Parameters:

  • dtype (DType): Data type of the input and output tensors.
  • fp8_dtype (DType): Data type of the fp8 weight tensors (w_uk and w_uv).
  • fp8_scale_dtype (DType): Data type of the fp8 scale tensors (w_uk_scale and w_uv_scale).
  • collection_t (KVCollectionT): Type of the KV collection.
  • m_scale_granularity (Int): Granularity of the fp8 scaling blocks along the M dimension of the matrix multiplication.
  • n_scale_granularity (Int): Granularity of the fp8 scaling blocks along the N dimension of the matrix multiplication.
  • k_scale_granularity (Int): Granularity of the fp8 scaling blocks along the K dimension of the matrix multiplication (see the dequantization sketch after this list).
  • mask_str (StringSlice): Mask variant.
  • score_mod_str (StringSlice): Positional encoding variant.
  • target (StringSlice): Target device.
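
The granularity parameters describe how many rows or columns of an operand share one fp8 scale. The exact scale layout follows the float8_config (see w_uk_scale below), but as a sketch of the common blockwise scheme, dequantizing a weight whose [n_scale_granularity, k_scale_granularity] blocks each carry one scale could look like this (dequantize_blockwise and all values are hypothetical):

```python
import numpy as np

def dequantize_blockwise(w_fp8, w_scale, n_g, k_g):
    """Hypothetical blockwise dequantization: each [n_g, k_g] block of the
    fp8 weight shares one scale. The real layout is set by the float8_config."""
    n, k = w_fp8.shape
    out = w_fp8.astype(np.float32).copy()
    for i in range(0, n, n_g):
        for j in range(0, k, k_g):
            out[i:i + n_g, j:j + k_g] *= w_scale[i // n_g, j // k_g]
    return out

# For a [512, 128] weight with 128x128 blocks, the scale tensor is [4, 1].
w_fp8 = np.ones((512, 128), dtype=np.float32)  # float32 stand-in for fp8 storage
w_scale = np.full((512 // 128, 128 // 128), 0.5, dtype=np.float32)
w = dequantize_blockwise(w_fp8, w_scale, 128, 128)
assert w.shape == (512, 128) and w[0, 0] == 0.5
```

m_scale_granularity would apply analogously along the M (activation-row) dimension of the matmul.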

Args:

  • output (LayoutTensor): Output tensor of shape [tot_seq_len, num_heads, v_head_dim].
  • q_nope (LayoutTensor): Query tensor of shape [tot_seq_len, num_heads, qk_nope_head_dim].
  • q_rope (LayoutTensor): Rope query tensor of shape [tot_seq_len, num_heads, qk_rope_head_dim].
  • input_row_offsets (LayoutTensor): Indicates where each request starts and ends in q. Shape: [num_batches + 1] (see the ragged-layout example after this list).
  • kv_collection (collection_t): Paged KV Cache object.
  • layer_idx (UInt32): Layer index.
  • scale (Float32): Scale for the attention calculation.
  • w_uk (LayoutTensor): Weight matrix for projecting the non-rope part of each query head to KV latent space. Shape: [num_heads, kv_latent_dim, qk_nope_head_dim].
  • w_uk_scale (LayoutTensor): The scale for the w_uk weight matrix. Shape varies depending on the float8_config.
  • w_uv (LayoutTensor): Weight matrix for projecting the output of the attention back to each head's original space. Shape: [num_heads, v_head_dim, kv_latent_dim].
  • w_uv_scale (LayoutTensor): The scale for the w_uv weight matrix. Shape varies depending on the float8_config.
  • ctx (DeviceContext): Device context.
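
To make the ragged layout concrete: input_row_offsets holds cumulative token offsets, so request i owns rows [input_row_offsets[i], input_row_offsets[i + 1]) of q_nope, q_rope, and output. A small illustrative example (values are made up):

```python
import numpy as np

# Three requests of 5, 2, and 7 tokens packed into one ragged batch.
seq_lens = [5, 2, 7]
input_row_offsets = np.cumsum([0] + seq_lens).astype(np.uint32)  # [0, 5, 7, 14]

tot_seq_len = int(input_row_offsets[-1])
q_nope = np.zeros((tot_seq_len, 8, 128), dtype=np.float32)  # [tot_seq_len, heads, dim]

# Rows belonging to request i:
for i in range(len(seq_lens)):
    start, end = input_row_offsets[i], input_row_offsets[i + 1]
    q_i = q_nope[start:end]  # this request's query rows
    assert q_i.shape[0] == seq_lens[i]
```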
