
Mojo module

mha_fp8_kv

SM100 (B200) ragged-paged MHA with FP8 KV cache + per-block fp32 scales.

Target hardware family: NVIDIA SM100 (Blackwell, B200).

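Implements the correctness-first MVP of the SM100 MHA FP8-KV decode path. FP8 K/V are dequantized to bf16 in a separate pass, and the existing bf16 attention path then runs on the staged result. Architecture: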

bf16 Q
   |
   v
+-------------------------------+
| mha_fp8_kv_decode_ragged(...) |
+-------------------------------+
   |
   | (1) dequant_paged_fp8_to_bf16:
   |     per-page, per-token, per-head_dim_block convert
   |     fp8 K, fp8 V -> staging bf16 K, V, applying
   |     `bf16(float(fp8) * fp32_scale[block(head_dim)])`
   v
+-------------------------------+
| flash_attention_ragged (bf16) |
| -> existing FA4 SM100 path    |
+-------------------------------+
   |
   v
bf16 output
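The per-element conversion in step (1), followed by the bf16 attention step, can be modeled on the host as a reference check. This Python sketch rests on several assumptions that the text above does not state. The FP8 format is taken to be E4M3FN. The fp32-to-bf16 conversion is taken to round to nearest even. Attention is single-head, unmasked, and uses a 1/sqrt(head_dim) softmax scale. All helper names (`decode_e4m3`, `to_f32`, `to_bf16`, `dequant_element`, `decode_attention`) are hypothetical and are not part of this module.

```python
import math
import struct


def decode_e4m3(byte: int) -> float:
    """Decode one FP8 E4M3FN byte (assumed format: no infinities) to a float."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    mant = byte & 0x7
    if exp == 0xF and mant == 0x7:
        return math.nan  # 0x7F and 0xFF are the only NaN encodings
    if exp == 0:
        return sign * (mant / 8.0) * 2.0 ** -6  # subnormal range
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - 7)  # exponent bias 7


def to_f32(x: float) -> float:
    """Round a Python float (fp64) to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round an fp32 value to bf16, round-to-nearest-even (NaN payloads not preserved)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 16) & 1  # lowest kept mantissa bit decides ties
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def dequant_element(fp8_byte: int, scales: list, d: int, g: int) -> float:
    """bf16(float(fp8) * fp32_scale[block(d)]), with block(d) = d // g (assumed)."""
    return to_bf16(to_f32(decode_e4m3(fp8_byte) * scales[d // g]))


def decode_attention(q, k_rows, v_rows):
    """Single-head decode attention over dequantized K/V rows (no masking)."""
    scale = 1.0 / math.sqrt(len(q))  # assumed softmax scale
    logits = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in k_rows]
    m = max(logits)  # subtract the max for a numerically stable softmax
    weights = [math.exp(x - m) for x in logits]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, v_rows)) / total
            for j in range(len(v_rows[0]))]


if __name__ == "__main__":
    # One 4-element K row with two scale blocks of g = 2 elements each.
    k_row = [dequant_element(b, [0.5, 2.0], d, 2)
             for d, b in enumerate([0x38, 0x40, 0x38, 0x40])]
    print(k_row)  # [0.5, 1.0, 2.0, 4.0]
```

Multiplying in fp32 before rounding to bf16 mirrors the order written in the `bf16(float(fp8) * fp32_scale[...])` expression. Rounding each step explicitly makes a host reference comparable element by element with the staging buffer.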

A future fused convert+scale-apply variant (mirroring the kv_load2cvt_pipe → kv_cvt2mma_pipe staging in mla_decode_kv_fp8.mojo) would replace the external dequant pass with an in-pipeline convert warp group (WG). The API surface exposed here remains stable across that change.

Constraints:

  • The dequant kernel queries load_scale (and the underlying PagedKVCache._get_scale_idx) only at block-start head_dim indices (d // g) * g, where g is the number of head_dim elements covered by one scale. This matches the floordiv semantics of _get_scale_idx.
  • The dequant operates on FP8 paged blocks indexed by the same lookup_table as the bf16 staging buffer.
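The two constraints above can be illustrated with a small host-side reference of the dequant pass. This Python sketch relies on a hypothetical layout that the text does not specify. Pages are stored as `pages[page][slot][d]`, with values already decoded from FP8. Scales are stored as `scales[page][slot][d // g]`. `lookup_table[seq][i]` gives the physical page of a sequence's i-th logical page. The names `block_start` and `dequant_pages_reference` are illustrative and are not this module's API. bf16 rounding is omitted for brevity.

```python
def block_start(d: int, g: int) -> int:
    """First head_dim index of the scale block containing d: (d // g) * g."""
    return (d // g) * g


def dequant_pages_reference(fp8_pages, scales, lookup_table, seq_lens, page_size, g):
    """Dequantize every cached token into a staging buffer with the same paged shape.

    fp8_pages[page][slot][d] holds FP8 values already decoded to floats, and
    scales[page][slot][b] holds the fp32 scale of head_dim block b (assumed layouts).
    """
    # Same [page][slot][d] shape as the source, so one lookup_table indexes both.
    staging = [[[0.0] * len(row) for row in page] for page in fp8_pages]
    for seq, n_tokens in enumerate(seq_lens):  # ragged: per-sequence token counts
        for t in range(n_tokens):
            page = lookup_table[seq][t // page_size]  # logical -> physical page
            slot = t % page_size
            src = fp8_pages[page][slot]
            for start in range(0, len(src), g):
                # The scale is read once per block, at the block-start index only.
                s = scales[page][slot][block_start(start, g) // g]
                for d in range(start, min(start + g, len(src))):
                    staging[page][slot][d] = src[d] * s
    return staging
```

Because block_start(d, g) // g == d // g for every d, querying only at block starts returns the same scale that a per-element query would. Because the staging buffer mirrors the paged shape, the bf16 attention stage can reuse the same lookup_table without remapping.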

Functions