Mojo function

mla_prefill_branch_fp8

mla_prefill_branch_fp8[dtype: DType, fp8_dtype: DType, fp8_scale_dtype: DType, collection_t: KVCollectionT, //, m_scale_granularity: Int, n_scale_granularity: Int, k_scale_granularity: Int, mask_str: StringSlice[StaticConstantOrigin], score_mod_str: StringSlice[StaticConstantOrigin], target: StringSlice[StaticConstantOrigin] = "cpu"](output: TileTensor[dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], q: TileTensor[dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], input_row_offsets: TileTensor[DType.uint32, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], freqs_cis: TileTensor[dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], kv_norm_gamma: TileTensor[dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], kv_collection: collection_t, layer_idx: UInt32, scale: Float32, epsilon: Float32, buffer_row_offsets: TileTensor[DType.uint32, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], cache_offsets: TileTensor[DType.uint32, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], buffer_length: Int, kv_b_proj: TileTensor[fp8_dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], kv_b_proj_scale: TileTensor[fp8_scale_dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], ctx: DeviceContext)

This is a manually fused kernel that performs the following operations:

  • Apply RoPE to the query and the key cache (in-place).
  • Apply RMSNorm to the non-rope portion of the key cache (in-place).
  • Copy the KV latent values from the PagedKVCache to a contiguous buffer.
  • Quantize the KV latent values to fp8.
  • Up-project the latent KV values to full K and V through a matmul.
  • Split the concatenated KV into K and V.
  • Perform MLA prefill.
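The fused steps above can be sketched at single-token granularity. This is a minimal NumPy illustration, not the kernel's implementation: all names, shapes, and the simulated-fp8 quantization (a scaled cast in fp32) are assumptions for clarity.

```python
# Hypothetical NumPy sketch of the fused steps this kernel performs,
# at single-token granularity; names and shapes are illustrative only.
import numpy as np

def rms_norm(x, gamma, eps):
    # RMSNorm: x / sqrt(mean(x^2) + eps) * gamma
    return x / np.sqrt(np.mean(x * x) + eps) * gamma

def quantize_fp8(x, max_fp8=448.0):
    # Per-tensor scale so values fit the fp8 e4m3 range (simulated in fp32).
    scale = np.max(np.abs(x)) / max_fp8
    return np.round(x / scale), scale  # quantized values + their scale

kv_lora_rank, num_heads = 8, 2
qk_nope_head_dim, v_head_dim = 4, 4
out_dim = num_heads * (qk_nope_head_dim + v_head_dim)

rng = np.random.default_rng(0)
latent = rng.standard_normal(kv_lora_rank).astype(np.float32)
gamma = np.ones(kv_lora_rank, dtype=np.float32)
w = rng.standard_normal((out_dim, kv_lora_rank)).astype(np.float32)

# 1) RMSNorm the latent KV values (the non-rope portion of the cache).
normed = rms_norm(latent, gamma, eps=1e-6)
# 2) Quantize to (simulated) fp8 and up-project through kv_b_proj.
q_latent, s = quantize_fp8(normed)
kv = (w @ q_latent) * s  # dequantize via the scale
# 3) Split the concatenated KV into the K (nope) part and V per head.
kv = kv.reshape(num_heads, qk_nope_head_dim + v_head_dim)
k_nope, v = kv[:, :qk_nope_head_dim], kv[:, qk_nope_head_dim:]
print(k_nope.shape, v.shape)
```

RoPE application and the paged-to-contiguous copy are omitted here; the sketch only shows the normalize → quantize → up-project → split sequence.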

Parameters:

  • dtype (DType): Data type of the input and output tensors.
  • fp8_dtype (DType): Data type of the fp8 input and output tensors.
  • fp8_scale_dtype (DType): Data type of the fp8 scale input and output tensors.
  • collection_t (KVCollectionT): Type of the KV collection.
  • m_scale_granularity (Int): Granularity of the fp8 scale along the M dimension of the matrix multiplication.
  • n_scale_granularity (Int): Granularity of the fp8 scale along the N dimension of the matrix multiplication.
  • k_scale_granularity (Int): Granularity of the fp8 scale along the K dimension of the matrix multiplication.
  • mask_str (StringSlice): Mask variant.
  • score_mod_str (StringSlice): Positional encoding variant.
  • target (StringSlice): Target device.
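Block-wise fp8 scale granularity can be pictured as one scale value per (n_scale_granularity × k_scale_granularity) tile of the weight matrix. The sketch below is a hypothetical NumPy illustration of that layout; the function name and the e4m3 range constant are assumptions, not the kernel's API.

```python
# Illustrative sketch of block-wise fp8 weight scales: a scale granularity
# of (n_g, k_g) means one scale per n_g x k_g block of the [N, K] weight.
import numpy as np

def block_scales(w, n_g, k_g, max_fp8=448.0):
    n, k = w.shape
    scales = np.empty((n // n_g, k // k_g), dtype=np.float32)
    for i in range(n // n_g):
        for j in range(k // k_g):
            blk = w[i * n_g:(i + 1) * n_g, j * k_g:(j + 1) * k_g]
            # One scale per block: largest magnitude mapped to the fp8 max.
            scales[i, j] = np.max(np.abs(blk)) / max_fp8
    return scales

w = np.arange(32, dtype=np.float32).reshape(8, 4)
s = block_scales(w, n_g=4, k_g=2)
print(s.shape)  # one scale per 4x2 block -> (2, 2)
```

This is why the shape of `kv_b_proj_scale` depends on the configured granularities: coarser granularity means fewer scale entries.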

Args:

  • output (TileTensor): Output tensor of shape [tot_seq_len, num_heads, v_head_dim].
  • q (TileTensor): Combined query tensor containing both nope and rope parts. Shape: [tot_seq_len, num_heads, qk_nope_head_dim + qk_rope_head_dim].
  • input_row_offsets (TileTensor): Indicates where each request starts and ends in q. Shape: [num_batches + 1].
  • freqs_cis (TileTensor): Precomputed RoPE frequency values for rotary position embeddings. Shape: [max_seq_len, qk_rope_head_dim].
  • kv_norm_gamma (TileTensor): RMSNorm gamma weights for normalizing the KV cache. Shape: [kv_lora_rank].
  • kv_collection (collection_t): Paged KV Cache object.
  • layer_idx (UInt32): Layer index.
  • scale (Float32): Scale for the attention calculation.
  • epsilon (Float32): Small constant for numerical stability in RMSNorm.
  • buffer_row_offsets (TileTensor): Indicates where each request's KV latent values should be stored in the contiguous K buffer. This is a 1D tensor of shape [num_batches + 1].
  • cache_offsets (TileTensor): Indicates the starting token position in the KV cache from which to copy KV latent values for each request. This is a 1D tensor of shape [num_batches + 1].
  • buffer_length (Int): The total number of tokens in the KV cache. Scalar.
  • kv_b_proj (TileTensor): Weight matrix for up-projecting the KV latent values to full K and V. Shape: [num_heads * (qk_nope_head_dim + v_head_dim), kv_latent_dim].
  • kv_b_proj_scale (TileTensor): The scale for the weight matrix. Shape varies depending on the float8_config.
  • ctx (DeviceContext): Device context.
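The ragged row-offset arguments all follow the same convention: request i occupies rows `offsets[i]` through `offsets[i+1] - 1`. A minimal NumPy sketch of that indexing, with made-up offset values:

```python
# Hypothetical sketch of how the ragged row-offset tensors index requests:
# request i occupies rows input_row_offsets[i] .. input_row_offsets[i+1]-1.
import numpy as np

input_row_offsets = np.array([0, 3, 7, 12], dtype=np.uint32)  # 3 requests
num_batches = len(input_row_offsets) - 1
seq_lens = np.diff(input_row_offsets.astype(np.int64))  # per-request lengths
tot_seq_len = int(input_row_offsets[-1])  # total rows in q and output

# buffer_row_offsets plays the same role for the contiguous KV buffer,
# while cache_offsets gives each request's starting token in the paged cache.
for i in range(num_batches):
    start, end = int(input_row_offsets[i]), int(input_row_offsets[i + 1])
    # q[start:end] are the query rows for request i
print(seq_lens.tolist(), tot_seq_len)
```

Under this convention, `tot_seq_len` in the `output` and `q` shapes is simply the last entry of `input_row_offsets`.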
