Mojo module
mla_graph
Functions
- `fused_rope_rmsnorm_kernel`: Fused GPU kernel that applies RoPE to the query projections and RMSNorm to the KV cache (see the first sketch after this list).
- `mla_decode_branch_bf16`: BF16 MLA decode path.
- `mla_decode_branch_fp8`: FP8 MLA decode path. This is a manually fused kernel that performs the following operations (see the decode dataflow sketch after this list):
  - Apply RoPE to the query and the key cache (in place).
  - Apply RMSNorm to the non-RoPE portion of the key cache (in place).
  - Project `q_nope` to `kv_latent_dim` through an FP8 batched matmul: `q_nope_proj = q_nope_t @ w_uk`.
  - Concatenate `q_nope_proj` and `q_rope`: `q_full = concat(q_nope_proj, q_rope, axis=2)`.
  - Perform MLA decode.
  - Project `raw_output` to `v_head_dim` through another FP8 batched matmul: `output = raw_output_t @ w_uv`.
- `mla_fused_rope_rmsnorm`: Launches the fused RoPE and RMSNorm kernel for MLA attention.
- `mla_prefill_branch_bf16`: BF16 MLA prefill path.
- `mla_prefill_branch_fp8`: FP8 MLA prefill path. This is a manually fused kernel that performs the following operations (see the prefill dataflow sketch after this list):
  - Apply RoPE to the query and the key cache (in place).
  - Apply RMSNorm to the non-RoPE portion of the key cache (in place).
  - Copy the KV latent values from the PagedKVCache to a contiguous buffer.
  - Quantize the KV latent values to FP8.
  - Up-project the latent KV values to full K and V through two matmuls.
  - Perform MLA prefill.
- `mla_prefill_decode_graph_bf16`: BF16 MLA prefill/decode graph.
- `mla_prefill_decode_graph_fp8`: FP8 MLA prefill/decode graph. This manually fused kernel performs MLA prefill or decode based on the maximum sequence length.
- `quantize_and_bmm_fp8_helper`: Helper function that quantizes the input and performs a batched matrix multiplication. This function uses the transposed view of the input tensor `a` (see the quantization sketch after this list).
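The fused RoPE/RMSNorm kernel combines two otherwise separate passes: rotary position embedding on the RoPE slice and RMSNorm on the non-RoPE latent slice. Below is a minimal NumPy sketch of the two operations only, assuming a half-split RoPE layout and hypothetical DeepSeek-style sizes (512-dim latent, 64-dim RoPE part); the actual kernel runs in place on GPU buffers.

```python
import numpy as np

def rope(x, pos, theta=10000.0):
    # Rotary embedding over the last dim, half-split layout (assumed).
    d = x.shape[-1]
    inv_freq = theta ** (-np.arange(0, d, 2) / d)
    ang = pos * inv_freq                          # (d/2,)
    cos, sin = np.cos(ang), np.sin(ang)
    x1, x2 = x[..., : d // 2], x[..., d // 2 :]
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)

def rmsnorm(x, weight, eps=1e-6):
    # RMSNorm over the last dim.
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * weight

# Hypothetical shapes: 16 query heads, 64-dim RoPE part, 512-dim latent part.
q_rope = np.random.randn(16, 64).astype(np.float32)
kv_latent = np.random.randn(1, 512 + 64).astype(np.float32)   # one cache row
gamma = np.ones(512, dtype=np.float32)

q_rope = rope(q_rope, pos=7)                                  # RoPE on the query
kv_latent[:, :512] = rmsnorm(kv_latent[:, :512], gamma)       # RMSNorm, non-RoPE part
kv_latent[:, 512:] = rope(kv_latent[:, 512:], pos=7)          # RoPE on the key's RoPE part
```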
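For `mla_decode_branch_fp8`, the step list above amounts to the dataflow sketched below in NumPy. All shapes are hypothetical DeepSeek-style sizes, the FP8 quantization of the two batched matmuls is omitted here (see the quantization sketch further down), the softmax scale is illustrative, and RoPE/RMSNorm are assumed to have already been applied.

```python
import numpy as np

H, D, L, R, V, T = 16, 128, 512, 64, 128, 32
# H heads, D no-RoPE head dim, L KV latent dim, R RoPE dim, V value head dim,
# T cached tokens -- all hypothetical sizes.

q_nope = np.random.randn(1, H, D).astype(np.float32)       # one decode token
q_rope = np.random.randn(1, H, R).astype(np.float32)       # already rotated
kv_cache = np.random.randn(T, L + R).astype(np.float32)    # latent + RoPE key rows
w_uk = np.random.randn(H, D, L).astype(np.float32)
w_uv = np.random.randn(H, L, V).astype(np.float32)

# Batched matmul over heads: q_nope_proj = q_nope_t @ w_uk (FP8 in the real kernel).
q_nope_t = q_nope.transpose(1, 0, 2)                       # (H, 1, D)
q_nope_proj = q_nope_t @ w_uk                              # (H, 1, L)

# q_full = concat(q_nope_proj, q_rope, axis=2).
q_full = np.concatenate([q_nope_proj, q_rope.transpose(1, 0, 2)], axis=2)  # (H, 1, L+R)

# MLA decode: attention of the full query against the shared latent cache.
scores = q_full @ kv_cache.T / np.sqrt(L + R)              # scale is illustrative
probs = np.exp(scores - scores.max(-1, keepdims=True))
probs /= probs.sum(-1, keepdims=True)
raw_output = probs @ kv_cache[:, :L]                       # (H, 1, L)

# output = raw_output_t @ w_uv (the second FP8 batched matmul; here the
# result is already head-major, so no extra transpose is needed).
output = raw_output @ w_uv                                 # (H, 1, V)
```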
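For `mla_prefill_branch_fp8`, a hedged sketch of the gather, quantize, and up-project steps. NumPy has no FP8 dtype, so e4m3 quantization is emulated with a scale and clamp (mantissa rounding omitted), the page/slot layout is invented for illustration, and weight quantization is skipped.

```python
import numpy as np

T, L, R, H, D, V = 8, 512, 64, 16, 128, 128
# T tokens, L latent dim, R RoPE dim, H heads, D/V key/value head dims (hypothetical).

# Step 1: gather latent rows from a paged cache into a contiguous buffer.
pages = np.random.randn(4, 4, L + R).astype(np.float32)      # 4 pages x 4 slots
slot_map = [(0, 1), (0, 2), (1, 0), (1, 3), (2, 0), (2, 1), (3, 2), (3, 3)]
kv_latent = np.stack([pages[p, s] for p, s in slot_map])     # (T, L+R) contiguous

# Step 2: quantize the latent values to FP8 (emulated e4m3: scale + clamp).
scale = 448.0 / np.abs(kv_latent[:, :L]).max()               # e4m3 max is 448
kv_q = np.clip(kv_latent[:, :L] * scale, -448.0, 448.0)      # stored as FP8 in the kernel

# Step 3: up-project the latent values to full K and V with two matmuls.
w_uk = np.random.randn(L, H * D).astype(np.float32)
w_uv = np.random.randn(L, H * V).astype(np.float32)
k_nope = (kv_q @ w_uk) / scale                               # dequantize by the scale
v = ((kv_q @ w_uv) / scale).reshape(T, H, V)
k = np.concatenate([k_nope.reshape(T, H, D),
                    np.repeat(kv_latent[:, None, L:], H, axis=1)], axis=-1)

# Step 4: standard MLA prefill attention then runs over q, k, v.
```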
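Finally, the quantize-then-batched-matmul pattern that `quantize_and_bmm_fp8_helper` implements, again with emulated FP8 and a hypothetical signature; note the transposed view of `a`, matching the description above.

```python
import numpy as np

E4M3_MAX = 448.0  # largest finite e4m3 value

def quantize_and_bmm_fp8(a, b):
    """Quantize `a` per tensor, then batched-matmul its transposed view with `b`.

    Emulated FP8: real kernels store e4m3 payloads; this sketch only
    reproduces the scale/clamp arithmetic, and `b` stays unquantized.
    """
    a_scale = E4M3_MAX / max(np.abs(a).max(), 1e-12)
    a_q = np.clip(a * a_scale, -E4M3_MAX, E4M3_MAX)   # "FP8" payload
    a_t = a_q.transpose(1, 0, 2)                      # transposed view of input tensor a
    return (a_t @ b) / a_scale                        # dequantize the result

a = np.random.randn(2, 16, 64).astype(np.float32)     # (batch, heads, k)
b = np.random.randn(16, 64, 32).astype(np.float32)    # (heads, k, n)
out = quantize_and_bmm_fp8(a, b)                      # (16, 2, 32)
```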