Mojo module

mla_graph

Functions

  • mla_decode_branch_fp8: This is a manually fused kernel that performs the following operations:
      - Project q_nope to kv_latent_dim through an fp8 batched matmul: q_nope_proj = q_nope_t @ w_uk.
      - Concatenate q_nope_proj and q_rope: q_full = concat(q_nope_proj, q_rope, axis=2).
      - Perform MLA decode.
      - Project raw_output to v_head_dim through another fp8 batched matmul: output = raw_output_t @ w_uv.
  • mla_prefill_branch_fp8: This is a manually fused kernel that performs the following operations:
      - Copy the KV latent values from PagedKVCache to a contiguous buffer.
      - Quantize the KV latent values to fp8.
      - Up-project the latent KV values to full K and V through a matmul.
      - Split the concatenated KV into K and V.
      - Perform MLA prefill.
  • mla_prefill_decode_graph_fp8: This is a manually fused kernel that dispatches to MLA prefill or MLA decode based on the maximum sequence length.
  • quantize_and_bmm_fp8_helper: Helper function to quantize and perform a batched matrix multiplication. This function uses the transposed view of the input tensor a.
  • transpose_helper: Helper function to transpose a tensor from [B, N, K] to [N, B, K] (or vice versa).
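The mla_decode_branch_fp8 sequence above can be sketched in NumPy. All shapes, dimension sizes, and the stand-in for the attention step are illustrative assumptions, not the kernel's actual signature; the real kernel runs the two batched matmuls in fp8, which is omitted here for clarity.

```python
import numpy as np

# Hypothetical shapes; the real kernel's dimensions come from the model config.
B, n_heads = 2, 4          # batch size, attention heads
nope_dim, rope_dim = 8, 4  # q_nope / q_rope head dims
kv_latent_dim, v_head_dim = 16, 8

rng = np.random.default_rng(0)
q_nope = rng.standard_normal((B, n_heads, nope_dim)).astype(np.float32)
q_rope = rng.standard_normal((B, n_heads, rope_dim)).astype(np.float32)
w_uk = rng.standard_normal((n_heads, nope_dim, kv_latent_dim)).astype(np.float32)
w_uv = rng.standard_normal((n_heads, kv_latent_dim, v_head_dim)).astype(np.float32)

# Step 1: transpose [B, N, K] -> [N, B, K] so the batched matmul runs per head,
# then project q_nope to kv_latent_dim: q_nope_proj = q_nope_t @ w_uk.
q_nope_t = np.transpose(q_nope, (1, 0, 2))              # [n_heads, B, nope_dim]
q_nope_proj = np.transpose(q_nope_t @ w_uk, (1, 0, 2))  # [B, n_heads, kv_latent_dim]

# Step 2: concatenate the projected and rotary parts along axis=2.
q_full = np.concatenate([q_nope_proj, q_rope], axis=2)  # [B, n_heads, kv_latent_dim + rope_dim]

# Step 3: MLA decode attention would run here; stand in with its output shape.
raw_output = rng.standard_normal((B, n_heads, kv_latent_dim)).astype(np.float32)

# Step 4: project raw_output to v_head_dim: output = raw_output_t @ w_uv.
raw_output_t = np.transpose(raw_output, (1, 0, 2))      # [n_heads, B, kv_latent_dim]
output = np.transpose(raw_output_t @ w_uv, (1, 0, 2))   # [B, n_heads, v_head_dim]

print(q_full.shape)  # (2, 4, 20)
print(output.shape)  # (2, 4, 8)
```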
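The mla_prefill_branch_fp8 steps admit a similar sketch. The page-gather from PagedKVCache is elided (the buffer starts contiguous here), fp8 is emulated with a per-tensor scale into the e4m3 maximum of 448, and the weight name w_ukv and all dimension sizes are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(2)
seq_len, kv_latent_dim = 6, 16
n_heads, k_head_dim, v_head_dim = 4, 8, 8

# Step 1: the real kernel copies latent values from PagedKVCache pages into a
# contiguous buffer; here they are already contiguous.
kv_latent = rng.standard_normal((seq_len, kv_latent_dim)).astype(np.float32)

# Step 2: quantize to a simulated fp8 (e4m3 max magnitude 448) via a scale.
scale = np.abs(kv_latent).max() / 448.0
kv_latent_q = np.round(kv_latent / scale)

# Step 3: up-project latent KV to the concatenated full K and V.
w_ukv = rng.standard_normal(
    (kv_latent_dim, n_heads * (k_head_dim + v_head_dim))
).astype(np.float32)
kv_full = ((kv_latent_q @ w_ukv) * scale).reshape(
    seq_len, n_heads, k_head_dim + v_head_dim
)

# Step 4: split the concatenated KV into K and V.
k, v = kv_full[..., :k_head_dim], kv_full[..., k_head_dim:]
print(k.shape, v.shape)  # (6, 4, 8) (6, 4, 8)

# Step 5: MLA prefill attention would run on k and v here.
```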
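The two helpers can likewise be approximated in NumPy. This is a minimal sketch, assuming a per-tensor fp8 scale; the real quantize_and_bmm_fp8_helper uses hardware fp8 arithmetic and, as documented above, operates on the transposed view of a.

```python
import numpy as np

def transpose_helper(a: np.ndarray) -> np.ndarray:
    """Swap the first two axes: [B, N, K] -> [N, B, K] (or vice versa)."""
    return np.ascontiguousarray(np.transpose(a, (1, 0, 2)))

def quantize_and_bmm_fp8_helper(a: np.ndarray, b: np.ndarray):
    """Quantize a to a simulated fp8 grid, then batched-matmul its transposed view with b.

    fp8 is emulated here by rounding on a per-tensor scale into the e4m3
    maximum of 448; the real helper uses hardware fp8 types.
    """
    FP8_E4M3_MAX = 448.0
    scale = np.abs(a).max() / FP8_E4M3_MAX
    a_q = np.round(a / scale)        # fp8-like quantization (simulated)
    a_t = transpose_helper(a_q)      # helper uses the transposed view of a
    return (a_t @ b) * scale, scale

rng = np.random.default_rng(1)
a = rng.standard_normal((2, 3, 4)).astype(np.float32)  # [B, N, K]
b = rng.standard_normal((3, 4, 5)).astype(np.float32)  # per-N weight batch
out, scale = quantize_and_bmm_fp8_helper(a, b)
print(out.shape)  # (3, 2, 5)
```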
