Mojo module
mla_decode_sm100_qkv_fp8
Native FP8 MLA decode kernel for SM100 (B200).
All-FP8 kernel: Q, K, V, and P are all FP8 e4m3 in SMEM. Uses the native FP8 tensor-core MMA (tcgen05.mma.kind::f8f6f4) for both the QK and PV products. Same three-warp-group (3-WG) structure as the BF16 kernel: softmax WG, correction WG, and MMA+Load+Store WG.
Q arrives as FP8 from TMA directly (like FlashInfer), no BF16 conversion. KV arrives as FP8 from TMA directly (half the bytes of BF16). P (softmax output) is written as FP8 e4m3 to a separate SMEM region. The FP8 tensorwise dequant scale is folded into the softmax QK scale.
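The scale folding described above can be illustrated with a small host-side sketch. It is not the kernel's code: the function names, the separate `q_scale` and `k_scale` per-tensor factors, and the sample values are all assumptions made for illustration. It only shows why multiplying the dequant scale into the softmax scale gives the same logits as dequantizing Q and K first.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def logits_dequant_first(q, keys, q_scale, k_scale, sm_scale):
    """Reference path: dequantize every element, then apply the softmax scale."""
    out = []
    for k in keys:
        dot = sum((qi * q_scale) * (ki * k_scale) for qi, ki in zip(q, k))
        out.append(dot * sm_scale)
    return out

def logits_folded(q, keys, q_scale, k_scale, sm_scale):
    """Folded path: one combined scalar applied to the raw (quantized) dot product."""
    folded = sm_scale * q_scale * k_scale  # hypothetical per-tensor scales
    return [folded * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]

# Stand-ins for quantized payloads; real FP8 e4m3 values would come from the tensors.
q = [1.0, -2.0, 0.5, 4.0]
keys = [[0.5, 1.0, -1.0, 2.0], [3.0, 0.0, 1.5, -0.5], [-1.0, 2.0, 2.0, 1.0]]
q_scale, k_scale = 0.125, 0.25        # illustrative tensorwise dequant scales
sm_scale = 1.0 / math.sqrt(len(q))    # illustrative softmax scale

ref = softmax(logits_dequant_first(q, keys, q_scale, k_scale, sm_scale))
fused = softmax(logits_folded(q, keys, q_scale, k_scale, sm_scale))
print(ref)
print(fused)
```

Because the scales are per-tensor constants, they factor out of the dot product, so the folded path needs no per-element dequantization before the MMA.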
SMEM Layout (native FP8):
  Q FP8: 64 x 576 x 1 = 36864 bytes (SWIZZLE_64B)
  KV stages: N x 64 x 576 x 1 bytes (SWIZZLE_64B, N = num_kv_stages, typically 4)
  P stages: N x 64 x 64 x 1 bytes (SWIZZLE_64B, separate from KV)
  max/li: 128 x 4 x 2 = 1024 bytes
  barriers: (6N+11) fixed + output barriers
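As a rough check of the layout above, the following sketch adds up the listed regions for a given number of KV stages. The function name and dictionary keys are hypothetical. Barrier storage is left out of the byte total because the layout gives only a barrier count, not a per-barrier size, and the number of output barriers is not stated.

```python
def smem_budget(num_kv_stages=4):
    """Sum the SMEM regions from the layout, in bytes (barriers excluded)."""
    q_bytes = 64 * 576 * 1                    # Q tile, 1 byte per FP8 element
    kv_bytes = num_kv_stages * 64 * 576 * 1   # KV pipeline stages
    p_bytes = num_kv_stages * 64 * 64 * 1     # P stages, separate from KV
    max_li_bytes = 128 * 4 * 2                # max/li region as given in the layout
    fixed_barriers = 6 * num_kv_stages + 11   # a count, not bytes; output barriers not included
    return {
        "q": q_bytes,
        "kv": kv_bytes,
        "p": p_bytes,
        "max_li": max_li_bytes,
        "total_without_barriers": q_bytes + kv_bytes + p_bytes + max_li_bytes,
        "fixed_barrier_count": fixed_barriers,
    }

if __name__ == "__main__":
    for name, value in smem_budget(4).items():
        print(f"{name}: {value}")
```

With the typical N = 4, the KV stages dominate the budget, so the stage count is the main value to adjust if the total must fit a given shared-memory limit.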
Structs