Mojo module
mha_fp8_kv
SM100 (B200) ragged-paged MHA with FP8 KV cache + per-block fp32 scales.
Target hardware family: NVIDIA SM100 (Blackwell, B200).
Implements the correctness MVP of the SM100 MHA fp8-KV decode path. Architecture:
bf16 Q
|
v
+-------------------------------+
| mha_fp8_kv_decode_ragged(...) |
+-------------------------------+
|
| (1) dequant_paged_fp8_to_bf16:
| per-page, per-token, per-head_dim_block convert
| fp8 K, fp8 V -> staging bf16 K, V, applying
| `bf16(float(fp8) * fp32_scale[block(head_dim)])`
v
+-------------------------------+
| flash_attention_ragged (bf16) |
| -> existing FA4 SM100 path |
+-------------------------------+
|
v
bf16 output

A future fused convert+scale-apply variant (mirroring
`mla_decode_kv_fp8.mojo`'s `kv_load2cvt_pipe` → `kv_cvt2mma_pipe`
staging) would replace the external dequant pass with an in-pipeline
convert warpgroup (WG). The surface exposed here stays stable across that change.
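The per-element formula in step (1) can be checked against a small host-side reference. The sketch below is plain Python, not the module's Mojo kernel. It rests on a few assumptions:

- The FP8 KV data uses the E4M3FN encoding. The page does not name the FP8 variant.
- `g` is the number of head_dim elements that share one fp32 scale.
- All helper names are hypothetical.

The fp32 multiply is emulated by rounding the product to binary32 before rounding it to bf16.

```python
import struct

def to_f32(x: float) -> float:
    """Round a Python float (binary64) to the nearest binary32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def fp8_e4m3fn_to_float(byte: int) -> float:
    """Decode one FP8 byte, ASSUMING E4M3FN (1 sign, 4 exponent bits with bias 7, 3 mantissa bits)."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    mant = byte & 0x7
    if exp == 0xF and mant == 0x7:
        return float("nan")                      # S.1111.111 is the only NaN; no infinities
    if exp == 0:
        return sign * (mant / 8.0) * 2.0 ** -6   # subnormal range
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - 7)

def round_to_bf16(x: float) -> float:
    """Round a binary32 value to bfloat16 (round-half-to-even), returned as a Python float."""
    if x != x:
        return x                                 # propagate NaN unchanged
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def dequant_row(fp8_row, scales, g):
    """bf16(float(fp8) * fp32_scale[d // g]) for each head_dim element d of one token and head."""
    return [round_to_bf16(to_f32(fp8_e4m3fn_to_float(b) * scales[d // g]))
            for d, b in enumerate(fp8_row)]

# 0x38 = 1.0, 0x40 = 2.0, 0x3C = 1.5, 0xB8 = -1.0; two scale blocks of width g = 2
print(dequant_row([0x38, 0x40, 0x3C, 0xB8], [0.5, 3.0], 2))  # [0.5, 1.0, 4.5, -3.0]
```

Because the decoded FP8 value has at most 4 significant bits, the binary64 product is exact. Rounding it once to binary32 therefore matches a true fp32 multiply.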
Constraints:
- The dequant kernel only queries `load_scale` (and the underlying
  `PagedKVCache._get_scale_idx`) at block-start head_dim indices
  `(d // g) * g`, where `g` is the head_dim scale-block size. This matches
  the floordiv semantics in `_get_scale_idx`.
- The dequant operates on FP8 paged blocks indexed by the same
  `lookup_table` as the bf16 staging buffer.
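Querying only block starts is safe because floordiv maps every index in `[k*g, (k+1)*g)` to the same scale index `k`, and that includes the block start `k*g` itself. The minimal Python sketch below illustrates the argument. It assumes a hypothetical `get_scale_idx` that mirrors the floordiv semantics described above, and it uses FP8 values already decoded to floats. It compares one scale query per block against a per-element lookup.

```python
def get_scale_idx(d: int, g: int) -> int:
    """Hypothetical stand-in for PagedKVCache._get_scale_idx: floordiv by block size g."""
    return d // g

def dequant_row_blockwise(row, scales, g):
    """Query the scale once per head_dim block, at block-start index (d // g) * g."""
    out = []
    for start in range(0, len(row), g):          # start == (d // g) * g for every d in this block
        s = scales[get_scale_idx(start, g)]      # one load_scale-style query per block
        out.extend(x * s for x in row[start:start + g])
    return out

def dequant_row_per_element(row, scales, g):
    """Reference: look the scale up separately for every element d."""
    return [x * scales[get_scale_idx(d, g)] for d, x in enumerate(row)]

row = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
print(dequant_row_blockwise(row, [0.5, 2.0, 10.0], 2))  # [0.5, 1.0, 6.0, 8.0, 50.0, 60.0]
```

The blockwise variant performs `ceil(head_dim / g)` scale queries instead of `head_dim`, and produces identical results.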
Functions
dequant_paged_fp8_kv_to_bf16: Enqueue the dequant kernel for one layer of a paged FP8 KV cache.
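The Python sketch below is a rough model of what one per-layer dequant pass covers. It is not the signature of `dequant_paged_fp8_kv_to_bf16`. The sketch does three things:

- Walks every valid token of every sequence.
- Maps the token's logical page through `lookup_table` to a physical block.
- Writes the scaled values into a staging buffer that uses the same physical indexing, so the bf16 attention pass can reuse the same `lookup_table`.

The nested-list layouts, the `dequant_layer` name, and its parameters are hypothetical. FP8 decoding and bf16 rounding are omitted so the indexing stays visible.

```python
def dequant_layer(fp8_blocks, scales, lookup_table, cache_lengths, page_size, g):
    """Hypothetical host-side model of one layer's dequant pass.

    fp8_blocks[phys][slot][head][d]  -- FP8 values, modeled here as already-decoded floats
    scales[phys][slot][head][blk]    -- fp32 scale for head_dim block blk = d // g (assumed layout)
    lookup_table[b][p]               -- physical block holding logical page p of sequence b
    """
    # Staging mirrors fp8_blocks' shape, so it is addressed by the same lookup_table.
    staging = [[[[0.0] * len(head) for head in slot] for slot in blk] for blk in fp8_blocks]
    for b, length in enumerate(cache_lengths):
        for t in range(length):                  # only tokens actually cached for sequence b
            phys = lookup_table[b][t // page_size]   # logical page -> physical block
            slot = t % page_size                     # token position inside that page
            for h, row in enumerate(fp8_blocks[phys][slot]):
                for d, x in enumerate(row):
                    staging[phys][slot][h][d] = x * scales[phys][slot][h][d // g]
    return staging
```

In this model, physical slots that hold no valid token (beyond a sequence's cache length) are left as zeros in the staging buffer.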