Mojo function
mla_prefill_branch_fp8
mla_prefill_branch_fp8[dtype: DType, fp8_dtype: DType, fp8_scale_dtype: DType, collection_t: KVCollectionT, //, m_scale_granularity: Int, n_scale_granularity: Int, k_scale_granularity: Int, mask_str: StringSlice[StaticConstantOrigin], score_mod_str: StringSlice[StaticConstantOrigin], target: StringSlice[StaticConstantOrigin] = "cpu"](output: TileTensor[dtype, output.LayoutType, output.origin, linear_idx_type=output.linear_idx_type, element_shape_types=output.element_shape_types], q: TileTensor[dtype, q.LayoutType, q.origin, linear_idx_type=q.linear_idx_type, element_shape_types=q.element_shape_types], input_row_offsets: TileTensor[DType.uint32, input_row_offsets.LayoutType, input_row_offsets.origin, linear_idx_type=input_row_offsets.linear_idx_type, element_shape_types=input_row_offsets.element_shape_types], freqs_cis: TileTensor[freqs_cis.dtype, freqs_cis.LayoutType, freqs_cis.origin, linear_idx_type=freqs_cis.linear_idx_type, element_shape_types=freqs_cis.element_shape_types], kv_norm_gamma: TileTensor[kv_norm_gamma.dtype, kv_norm_gamma.LayoutType, kv_norm_gamma.origin, linear_idx_type=kv_norm_gamma.linear_idx_type, element_shape_types=kv_norm_gamma.element_shape_types], kv_collection: collection_t, layer_idx: UInt32, scale: Float32, epsilon: Float32, buffer_row_offsets: TileTensor[DType.uint32, buffer_row_offsets.LayoutType, buffer_row_offsets.origin, linear_idx_type=buffer_row_offsets.linear_idx_type, element_shape_types=buffer_row_offsets.element_shape_types], cache_offsets: TileTensor[DType.uint32, cache_offsets.LayoutType, cache_offsets.origin, linear_idx_type=cache_offsets.linear_idx_type, element_shape_types=cache_offsets.element_shape_types], buffer_length: Int, w_k: TileTensor[fp8_dtype, w_k.LayoutType, w_k.origin, linear_idx_type=w_k.linear_idx_type, element_shape_types=w_k.element_shape_types], w_k_scale: TileTensor[fp8_scale_dtype, w_k_scale.LayoutType, w_k_scale.origin, linear_idx_type=w_k_scale.linear_idx_type, 
element_shape_types=w_k_scale.element_shape_types], w_uv: TileTensor[fp8_dtype, w_uv.LayoutType, w_uv.origin, linear_idx_type=w_uv.linear_idx_type, element_shape_types=w_uv.element_shape_types], w_uv_scale: TileTensor[fp8_scale_dtype, w_uv_scale.LayoutType, w_uv_scale.origin, linear_idx_type=w_uv_scale.linear_idx_type, element_shape_types=w_uv_scale.element_shape_types], ctx: DeviceContext)
This is a manually fused kernel that performs the following operations:

- Apply RoPE to the query and the key cache (in-place).
- Apply RMSNorm to the non-rope portion of the key cache (in-place).
- Copy the KV latent values from PagedKVCache to a contiguous buffer.
- Quantize the KV latent values to fp8.
- Up-project the latent KV values to full K and V through two matmuls.
- Perform MLA prefill.
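The first two in-place steps can be sketched numerically. The NumPy sketch below is illustrative only (it is not the Mojo API; the helper names, dimensions, and frequency construction are assumptions): RoPE rotates the rope portion of each entry by position-dependent complex phases, and RMSNorm rescales the non-rope (latent) portion.

```python
import numpy as np

def apply_rope(x, freqs_cis):
    # x: [rope_dim], viewed as rope_dim // 2 (real, imag) pairs.
    # freqs_cis: [rope_dim // 2] precomputed unit-complex rotations.
    pairs = x.reshape(-1, 2).astype(np.float32)
    c = (pairs[:, 0] + 1j * pairs[:, 1]) * freqs_cis
    return np.stack([c.real, c.imag], axis=-1).reshape(-1)

def rms_norm(x, gamma, epsilon=1e-6):
    # Normalize by the root mean square, then scale by learned gamma.
    rms = np.sqrt(np.mean(x.astype(np.float32) ** 2) + epsilon)
    return (x / rms) * gamma

# Assumed toy dimensions: a single cache entry of latent + rope parts.
rope_dim, latent_dim = 64, 512
kv_entry = np.random.randn(latent_dim + rope_dim).astype(np.float32)
pos_freqs = np.exp(1j * np.random.uniform(0, 2 * np.pi, rope_dim // 2))

gamma = np.ones(latent_dim, dtype=np.float32)
kv_entry[:latent_dim] = rms_norm(kv_entry[:latent_dim], gamma)   # non-rope part
kv_entry[latent_dim:] = apply_rope(kv_entry[latent_dim:], pos_freqs)  # rope part
```

The fused kernel performs these steps directly on the paged cache before copying, quantizing, and up-projecting the latents.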
Parameters:
- dtype (DType): Data type of the input and output tensors.
- fp8_dtype (DType): Data type of the fp8 input and output tensors.
- fp8_scale_dtype (DType): Data type of the fp8 scale input and output tensors.
- collection_t (KVCollectionT): Type of the KV collection.
- m_scale_granularity (Int): Granularity of the scale for the M dimension of the matrix multiplication.
- n_scale_granularity (Int): Granularity of the scale for the N dimension of the matrix multiplication.
- k_scale_granularity (Int): Granularity of the scale for the K dimension of the matrix multiplication.
- mask_str (StringSlice): Mask variant.
- score_mod_str (StringSlice): Positional encoding variant.
- target (StringSlice): Target device.
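The scale-granularity parameters control how many elements of a matmul operand share one fp8 scale factor. As a hedged sketch of the idea (not the kernel's implementation; the E4M3 maximum of 448 and the block layout are assumptions), blockwise quantization along the K dimension looks like this:

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # max representable magnitude in float8_e4m3

def quantize_blockwise(w, k_scale_granularity):
    # w: [N, K]. Each contiguous block of k_scale_granularity columns
    # shares one scale, chosen so the block fits the fp8 range.
    n, k = w.shape
    assert k % k_scale_granularity == 0
    blocks = w.reshape(n, k // k_scale_granularity, k_scale_granularity)
    scales = np.abs(blocks).max(axis=-1, keepdims=True) / FP8_E4M3_MAX
    scales = np.maximum(scales, 1e-12)            # avoid divide-by-zero
    q = blocks / scales                           # values now within fp8 range
    return q.reshape(n, k), scales.squeeze(-1)    # quantized weights + scales

w = np.random.randn(4, 8).astype(np.float32)
q, s = quantize_blockwise(w, k_scale_granularity=4)
# Dequantize: multiply each block back by its scale.
deq = (q.reshape(4, 2, 4) * s[..., None]).reshape(4, 8)
```

Smaller granularities track local dynamic range more closely at the cost of more scale storage; a granularity equal to the full dimension degenerates to per-row scaling.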
Args:
- output (TileTensor): Output tensor of shape [tot_seq_len, num_heads, v_head_dim].
- q (TileTensor): Combined query tensor containing both nope and rope parts. Shape: [tot_seq_len, num_heads, qk_nope_head_dim + qk_rope_head_dim].
- input_row_offsets (TileTensor): Indicates where each request starts and ends in q. Shape: [num_batches + 1].
- freqs_cis (TileTensor): Precomputed RoPE frequency values for rotary position embeddings. Shape: [max_seq_len, qk_rope_head_dim].
- kv_norm_gamma (TileTensor): RMSNorm gamma weights for normalizing the KV cache. Shape: [kv_lora_rank].
- kv_collection (collection_t): Paged KV cache object.
- layer_idx (UInt32): Layer index.
- scale (Float32): Scale for the attention calculation.
- epsilon (Float32): Small constant for numerical stability in RMSNorm.
- buffer_row_offsets (TileTensor): Indicates where each request's KV latent values should be stored in the contiguous K buffer. This is a 1D tensor of shape [num_batches + 1].
- cache_offsets (TileTensor): Indicates the starting token position in the KV cache from which to copy KV latent values for each request. This is a 1D tensor of shape [num_batches + 1].
- buffer_length (Int): The total number of tokens in the KV cache.
- w_k (TileTensor): Weight matrix for up-projecting the latent cache to full K. Shape: [num_heads * qk_nope_head_dim, kv_latent_dim].
- w_k_scale (TileTensor): Scale tensor for w_k.
- w_uv (TileTensor): Weight tensor for up-projecting latent values to V. Shape: [num_heads, v_head_dim, kv_latent_dim].
- w_uv_scale (TileTensor): Scale tensor for w_uv.
- ctx (DeviceContext): Device context.
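The ragged layout implied by input_row_offsets can be illustrated with plain Python (the batch sizes below are made up; this shows only the offset convention, not the kernel): request b owns rows input_row_offsets[b] through input_row_offsets[b + 1] - 1 of q, and buffer_row_offsets plays the same role for the contiguous K buffer.

```python
# Three requests with 5, 2, and 9 new tokens each.
seq_lens = [5, 2, 9]
input_row_offsets = [0]
for n in seq_lens:
    input_row_offsets.append(input_row_offsets[-1] + n)
# input_row_offsets is now a prefix sum; the last entry is tot_seq_len.

def rows_for_request(b):
    # Rows of q belonging to request b.
    return range(input_row_offsets[b], input_row_offsets[b + 1])
```

The num_batches + 1 length means the last entry is the total row count, so each request's extent is always a pair of adjacent offsets.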