IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo module

mla_decode_qkv_fp8

Native FP8 MLA decode kernel for SM100 (B200).

All-FP8 kernel: Q, K, V, and P are all held in SMEM as FP8 e4m3. Both the QK and PV products use the native FP8 tensor-core MMA (tcgen05.mma.kind::f8f6f4). The kernel keeps the same three-warpgroup (3-WG) structure as the BF16 kernel: a softmax WG, a correction WG, and a combined MMA+Load+Store WG.
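The softmax and correction warpgroup roles correspond to the online-softmax recurrence used by FlashAttention-style kernels. The sequential Python model below sketches that recurrence for a single query row. It assumes that the softmax WG computes the running max and P, the correction WG rescales the accumulator and row sum when the max grows, and the MMA WG performs the QK and PV products. That mapping is an assumption based on the role names, not taken from the kernel source. FP8 rounding is not modeled, and `attention_online` and `attention_reference` are hypothetical helpers.

```python
import math


def attention_reference(q, keys, values, scale):
    """Plain softmax(scale * q.k) @ V for one query row, all KV at once."""
    s = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    l = sum(p)
    dv = len(values[0])
    return [sum(p[j] * values[j][d] for j in range(len(p))) / l for d in range(dv)]


def attention_online(q, keys, values, scale, tile=2):
    """Same result, but KV is consumed one tile at a time (online softmax)."""
    dv = len(values[0])
    row_max = -math.inf  # running row max
    row_sum = 0.0        # running softmax denominator
    acc = [0.0] * dv     # unnormalized output accumulator
    for start in range(0, len(keys), tile):
        k_tile = keys[start:start + tile]
        v_tile = values[start:start + tile]
        # QK product for this KV tile (MMA role, assumed)
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in k_tile]
        # Softmax role (assumed): new running max, P relative to it
        new_max = max(row_max, max(s))
        p = [math.exp(x - new_max) for x in s]
        # Correction role (assumed): rescale old state to the new max.
        # On the first tile row_max is -inf, so alpha is exp(-inf) == 0.0.
        alpha = math.exp(row_max - new_max)
        acc = [a * alpha for a in acc]
        row_sum = row_sum * alpha + sum(p)
        row_max = new_max
        # PV product: accumulate P @ V_tile (MMA role, assumed)
        for j, pj in enumerate(p):
            for d in range(dv):
                acc[d] += pj * v_tile[j][d]
    return [a / row_sum for a in acc]
```

Calling `attention_online` with any tile size returns the same vector as `attention_reference` up to floating-point rounding. This is what lets a kernel stream KV through a fixed number of SMEM stages.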

Q is loaded as FP8 directly by TMA (as in FlashInfer), with no BF16 conversion. KV is likewise loaded as FP8 directly by TMA, which moves half the bytes of BF16. P (the softmax output) is written as FP8 e4m3 to a separate SMEM region. The FP8 tensorwise dequantization scale is folded into the softmax QK scale.
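Folding works because the scores are linear in Q and K, so per-tensor dequantization scales factor out of the dot product. With q = s_q * q8 and k = s_k * k8, the scaled score is alpha * (q . k) = (alpha * s_q * s_k) * (q8 . k8). The minimal Python sketch below demonstrates that identity. It assumes scalar per-tensor scales `q_scale` and `k_scale` (set one to 1.0 if only one tensor carries a scale). Plain floats stand in for the FP8 codes, so e4m3 rounding is not modeled, and the function names are hypothetical.

```python
import math


def softmax(xs):
    m = max(xs)
    e = [math.exp(x - m) for x in xs]
    total = sum(e)
    return [v / total for v in e]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def probs_dequant_first(q_fp8, k_fp8, q_scale, k_scale, softmax_scale):
    # Reference: dequantize every element, then take scaled dot products.
    q = [x * q_scale for x in q_fp8]
    scores = [softmax_scale * dot(q, [x * k_scale for x in k]) for k in k_fp8]
    return softmax(scores)


def probs_folded(q_fp8, k_fp8, q_scale, k_scale, softmax_scale):
    # Folded: the MMA consumes the raw FP8 values unchanged, and a single
    # combined scalar multiplies each score before the softmax.
    folded = softmax_scale * q_scale * k_scale
    return softmax([folded * dot(q_fp8, k) for k in k_fp8])
```

The folded form adds no extra work, because the softmax already multiplies each score by a scale and only that constant changes.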

SMEM layout (native FP8):
- Q (FP8): 64 x 576 x 1 = 36864 bytes (SWIZZLE_64B)
- KV stages: N x 64 x 576 x 1 bytes (SWIZZLE_64B), where N = num_kv_stages (typically 4); 36864 bytes per stage, 147456 bytes for N = 4
- P stages: N x 64 x 64 x 1 bytes (SWIZZLE_64B, separate from KV); 4096 bytes per stage, 16384 bytes for N = 4
- max/li: 128 x 4 x 2 = 1024 bytes
- Barriers: (6N + 11) fixed barriers plus the output barriers
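The region sizes can be summed to estimate the kernel's shared-memory footprint as a function of the stage count. The minimal Python sketch below makes several assumptions. It takes 8 bytes per barrier (the size of a PTX mbarrier object), reads P's 64 x 64 tile as Q rows x KV rows, and ignores any alignment padding between regions. `smem_bytes` and its parameters are hypothetical names, not part of this module.

```python
MBARRIER_BYTES = 8  # a PTX mbarrier object is 64 bits


def smem_bytes(num_kv_stages=4, q_rows=64, kv_rows=64, head_dim=576,
               num_output_barriers=0):
    """Hypothetical SMEM estimate; FP8 is 1 byte per element, padding ignored."""
    q = q_rows * head_dim * 1
    kv = num_kv_stages * kv_rows * head_dim * 1
    p = num_kv_stages * q_rows * kv_rows * 1   # assumed: P tile is Q rows x KV rows
    max_li = 128 * 4 * 2
    barriers = (6 * num_kv_stages + 11 + num_output_barriers) * MBARRIER_BYTES
    return {
        "q": q,
        "kv": kv,
        "p": p,
        "max_li": max_li,
        "barriers": barriers,
        "total": q + kv + p + max_li + barriers,
    }
```

For N = 4 with no output barriers this gives 202008 bytes. Each additional KV stage adds 36864 + 4096 bytes of buffers plus 6 barriers (48 bytes).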

Structs