Mojo module
mla_decode_sm100_qkv_fp8_per_token_scale_rope_aware
SnapMLA FP8+BF16 MLA decode kernel for SM100 (B200).
Split content/rope kernel with per-token FP8 scaling (Steps 1+2):
- Content (nope): FP8 e4m3 for both Q_nope and K_nope (512 dims)
- Rope: BF16 for both Q_rope and K_rope (64 dims)
- P (softmax output): FP8 e4m3 reusing KV rope SMEM (P_i maps to rope stage i)
- V: FP8 e4m3 content-only (512 dims)
Uses native FP8 tensor-core MMA (tcgen05.mma.kind::f8f6f4) for the content QK and PV products, and BF16 MMA (tcgen05.mma.kind::f16) for the rope QK product. Keeps the same 3-WG structure as the BF16 kernel (softmax WG, correction WG, MMA+Load+Store WG).
KV Cache Layout (640 bytes per row):
- Bytes 0-511: FP8 content (512 dims, kv_lora_rank)
- Bytes 512-639: BF16 rope (64 dims × 2 bytes)
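To illustrate the 640-byte row layout, the following Python sketch splits one row into its raw FP8 content bytes and its decoded BF16 rope values. The helper names (`split_kv_row`, `bf16_bits_to_float`) are hypothetical, little-endian storage of the rope half is an assumption, and FP8 e4m3 decoding is deliberately left out.

```python
import struct

# Offsets from the layout above: 512 FP8 bytes, then 64 BF16 values.
CONTENT_BYTES = 512                          # kv_lora_rank FP8 e4m3 values, 1 byte each
ROPE_DIMS = 64                               # BF16 rope values, 2 bytes each
ROW_BYTES = CONTENT_BYTES + ROPE_DIMS * 2    # 640

def bf16_bits_to_float(bits: int) -> float:
    # A BF16 value is the upper 16 bits of an IEEE-754 float32.
    return struct.unpack("<f", struct.pack("<I", bits << 16))[0]

def split_kv_row(row: bytes):
    """Return (raw FP8 content bytes, decoded BF16 rope floats) for one row.

    Assumes the BF16 rope half is stored little-endian (hypothetical helper).
    """
    if len(row) != ROW_BYTES:
        raise ValueError(f"expected {ROW_BYTES} bytes, got {len(row)}")
    content_fp8 = row[:CONTENT_BYTES]                     # bytes 0-511
    rope_bits = struct.unpack(f"<{ROPE_DIMS}H", row[CONTENT_BYTES:])  # bytes 512-639
    rope = [bf16_bits_to_float(b) for b in rope_bits]
    return content_fp8, rope

# Example: zeroed FP8 content plus 64 BF16 ones (bit pattern 0x3F80).
example_row = bytes(CONTENT_BYTES) + struct.pack(f"<{ROPE_DIMS}H", *([0x3F80] * ROPE_DIMS))
content, rope = split_kv_row(example_row)
```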
Per-Token FP8 Scaling (SnapMLA Approach):
Each KV token t has a per-token scale sigma_KV[t] (one float32 value). In MLA's absorbed mode, K and V are both derived from the same latent c_KV, so sigma_KV[t] is shared between K and V dequantization.
sigma_Q is per query token: each Q position has its own float32 scale. All BM=64 heads in a CTA share the same Q token, so sigma_Q is constant per CTA and is folded into scale_log2e inside the Softmax function:
scale_log2e = (1/sqrt(d_qk)) * sigma_Q[q_token_idx]
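Because sigma_Q is a single scalar per CTA, folding it into the softmax scale is equivalent to dequantizing Q before the dot product. The sketch below shows this for a toy 2-head score tile. It follows the formula exactly as written above; whether a log2(e) factor is also multiplied in at this point is not modeled, and the function names are hypothetical.

```python
import math

def folded_scale(d_qk: int, sigma_q: float) -> float:
    # (1/sqrt(d_qk)) * sigma_Q[q_token_idx], as in the formula above.
    return (1.0 / math.sqrt(d_qk)) * sigma_q

def scale_scores(raw_scores, d_qk: int, sigma_q: float):
    # One scalar per CTA: every head's row of raw scores gets the same factor.
    s = folded_scale(d_qk, sigma_q)
    return [[x * s for x in row] for row in raw_scores]

# Toy example: two heads, two KV columns, d_qk = 64 so 1/sqrt(d_qk) = 0.125.
raw = [[4.0, -8.0], [16.0, 2.0]]
scaled = scale_scores(raw, d_qk=64, sigma_q=2.0)
```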
QK scoring: After the combined scores S = content_raw + rope_raw are read from TMEM, each column t is multiplied by sigma_KV[t] BEFORE the log2e softmax scaling. This is mathematically exact under Scale Domain Alignment (Eq. 6): Q_rope and K_rope are pre-divided by their respective content scales (sigma_Q and sigma_KV[t]) before entering the kernel, so a single sigma_Q * sigma_KV[t] factor applied to the combined score dequantizes the content and rope terms alike.
PV dequant: Before P is written to FP8 SMEM, each column t of the softmax output is multiplied by sigma_KV[t], pre-fusing the V dequant scale: P'[t] = P[t] * sigma_KV[t], so the PV MMA computes sum_t P'[t] * V_fp8[t] = sum_t P[t] * sigma_KV[t] * V_fp8[t].
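The scale bookkeeping above can be checked numerically. This Python sketch compares plain softmax attention on dequantized inputs with a kernel-style path that uses raw FP8-domain values, per-column sigma_KV[t], sigma_Q folded into the scale, an exp2-based softmax, and the P' = P * sigma_KV fusion. Assumptions: toy sizes (4 content dims, 2 rope dims, 3 KV tokens), no FP8 rounding is emulated, V is the raw content latent, and the row sum is taken over P before the sigma_KV multiply (which the math requires). All names are hypothetical.

```python
import math

LOG2E = math.log2(math.e)

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def reference_attention(q_nope, q_rope, k_nope, k_rope, scale):
    """Softmax attention on real-valued (dequantized) inputs; V = K_nope."""
    s = [scale * (dot(q_nope, kn) + dot(q_rope, kr)) for kn, kr in zip(k_nope, k_rope)]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    l = sum(p)
    return [sum(p[t] * k_nope[t][j] for t in range(len(p))) / l for j in range(len(q_nope))]

def kernel_style_attention(q_fp8, q_rope_al, c_fp8, k_rope_al, sigma_q, sigma_kv, d_qk):
    scale_sigma_q = (1.0 / math.sqrt(d_qk)) * sigma_q          # sigma_Q folded in
    # Content MMA plus rope MMA accumulated onto it, in the raw scale domain.
    s_raw = [dot(q_fp8, c) + dot(q_rope_al, kr) for c, kr in zip(c_fp8, k_rope_al)]
    # Column-wise sigma_KV[t] first, then scaling into the log2 domain for exp2.
    s = [x * sigma_kv[t] * scale_sigma_q * LOG2E for t, x in enumerate(s_raw)]
    m = max(s)
    p = [2.0 ** (x - m) for x in s]
    l = sum(p)                                                 # row sum uses unscaled P
    p_prime = [p[t] * sigma_kv[t] for t in range(len(p))]      # P'[t] = P[t] * sigma_KV[t]
    return [sum(p_prime[t] * c_fp8[t][j] for t in range(len(p))) / l for j in range(len(q_fp8))]

# Toy inputs in the "quantized" domain, plus their scales.
sigma_q, sigma_kv, d_qk = 0.5, [0.25, 2.0, 1.0], 6
q_fp8 = [1.0, -2.0, 4.0, 0.5]
c_fp8 = [[2.0, 1.0, -1.0, 0.0], [0.5, 0.5, 1.0, -2.0], [-1.0, 3.0, 0.0, 1.0]]
q_rope_true = [0.3, -0.7]
k_rope_true = [[1.0, 0.0], [0.2, 0.4], [-0.5, 1.5]]

# Scale Domain Alignment: rope inputs pre-divided by the matching content scale.
q_rope_al = [x / sigma_q for x in q_rope_true]
k_rope_al = [[x / sigma_kv[t] for x in kr] for t, kr in enumerate(k_rope_true)]

q_deq = [sigma_q * x for x in q_fp8]
k_deq = [[sigma_kv[t] * x for x in c] for t, c in enumerate(c_fp8)]
ref = reference_attention(q_deq, q_rope_true, k_deq, k_rope_true, 1.0 / math.sqrt(d_qk))
out = kernel_style_attention(q_fp8, q_rope_al, c_fp8, k_rope_al, sigma_q, sigma_kv, d_qk)
```

Dividing by the sum of P' instead of P would bias the output toward tokens with large sigma_KV, which is why the sketch keeps the two quantities separate.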
MMA in QK:
- FP8 content MMA: Q_nope(FP8) × K_nope(FP8) → S in TMEM (c_scale=0, 8 blocks)
- BF16 rope MMA: Q_rope(BF16) × K_rope(BF16) → accumulate onto S (c_scale=1, 1 block)
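The QK sequence above amounts to a K-blocked matrix product: the content MMAs produce S and the rope MMA accumulates onto it. This Python sketch models that with integer toy data, reading "8 blocks" as 8 K-blocks of 64 dims (512 / 8), with only the first content block using c_scale=0 (overwrite) and every later block, including the rope block, using c_scale=1 (accumulate). That reading, the function names, and the stale initial accumulator value are assumptions; the individual tcgen05.mma instructions inside each block are not modeled.

```python
import random

CONTENT_DIMS, CONTENT_BLOCK = 512, 64   # 8 FP8 K-blocks of 64 dims
ROPE_DIMS = 64                          # 1 BF16 K-block

def mma_k_block(acc, a, b, k0, k1, c_scale):
    """acc[i][j] = c_scale * acc[i][j] + sum over k in [k0, k1) of a[i][k] * b[j][k]."""
    for i in range(len(a)):
        for j in range(len(b)):
            partial = sum(a[i][k] * b[j][k] for k in range(k0, k1))
            acc[i][j] = c_scale * acc[i][j] + partial
    return acc

def qk_scores(q_nope, q_rope, k_nope, k_rope):
    # Stale values stand in for leftover TMEM contents; c_scale=0 discards them.
    acc = [[999] * len(k_nope) for _ in q_nope]
    for blk in range(CONTENT_DIMS // CONTENT_BLOCK):
        k0 = blk * CONTENT_BLOCK
        mma_k_block(acc, q_nope, k_nope, k0, k0 + CONTENT_BLOCK, 0 if blk == 0 else 1)
    # Rope MMA accumulates onto the content scores (c_scale=1).
    return mma_k_block(acc, q_rope, k_rope, 0, ROPE_DIMS, 1)

rng = random.Random(0)
rand_rows = lambda n, d: [[rng.randint(-8, 8) for _ in range(d)] for _ in range(n)]
q_nope, q_rope = rand_rows(2, CONTENT_DIMS), rand_rows(2, ROPE_DIMS)   # 2 heads
k_nope, k_rope = rand_rows(3, CONTENT_DIMS), rand_rows(3, ROPE_DIMS)   # 3 KV tokens
scores = qk_scores(q_nope, q_rope, k_nope, k_rope)
```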
SMEM Layout:
- Q_nope FP8: 64 × 512 × 1 = 32768 bytes (SWIZZLE_64B)
- Q_rope BF16: 64 × 64 × 2 = 8192 bytes (SWIZZLE_128B)
- KV content: N × 64 × 512 × 1 bytes (SWIZZLE_64B, N = num_kv_stages)
- KV rope: N × 64 × 64 × 2 bytes (SWIZZLE_128B)
- P stages: reuse the KV rope region (P_i in rope stage i; a 4096 B FP8 P tile fits in an 8192 B BF16 rope stage)
- max/li: 128 × 4 × 3 = 1536 bytes
- Per-token scales: N × 64 × 1 × 4 bytes (float32 sigma_KV per KV token)
- Barriers: (6N+11) fixed + output barriers
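The per-region sizes above can be totaled for a given stage count N. A minimal Python sketch, assuming 64 KV tokens per stage (the 64 in the KV rows above), 8-byte mbarrier objects, no alignment padding, and a hypothetical `num_output_barriers` parameter, since the output-barrier count is not given here:

```python
BM = 64             # query heads per CTA
BN = 64             # KV tokens per stage (assumed from the "N × 64 × ..." rows)
D_NOPE = 512        # FP8 content dims (kv_lora_rank)
D_ROPE = 64         # BF16 rope dims
MBARRIER_BYTES = 8  # an mbarrier object occupies 64 bits of shared memory

def smem_budget(num_kv_stages: int, num_output_barriers: int = 0) -> dict:
    """Byte count per SMEM region, ignoring alignment padding (sketch only)."""
    n = num_kv_stages
    # P_i (BM x BN FP8) must fit inside rope stage i (BN x D_ROPE BF16).
    assert BM * BN * 1 <= BN * D_ROPE * 2
    sizes = {
        "q_nope_fp8": BM * D_NOPE * 1,
        "q_rope_bf16": BM * D_ROPE * 2,
        "kv_content_fp8": n * BN * D_NOPE * 1,
        "kv_rope_bf16": n * BN * D_ROPE * 2,      # P stages alias this region
        "max_li": 128 * 4 * 3,
        "per_token_scales": n * BN * 1 * 4,       # float32 sigma_KV per token
        "barriers": (6 * n + 11 + num_output_barriers) * MBARRIER_BYTES,
    }
    sizes["total"] = sum(sizes.values())
    return sizes

two_stage = smem_budget(2)
three_stage = smem_budget(3)
```

Under these assumptions, N = 2 comes to 125,112 bytes before output barriers and padding; the total would then be compared against the target device's per-block shared-memory limit when choosing num_kv_stages.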
Structs