Mojo module
amd_4wave_split_k_matmul
Single-launch split-K wrapper for the 4-wave FP8 matmul.
Targets the small-M decode regime where the base 4-wave kernel doesn't saturate 304 MI355X CUs at the natural launch geometry (M ≤ 64, N=K=8192, BM=BN=64 → 128 WGs at full K, ~42% of CUs in 1 wave; the per-WG K-loop dominates wall-clock).
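The occupancy figure above follows from simple grid arithmetic. The sketch below is a plain-Python illustration, not part of the module: `launch_occupancy` is a hypothetical helper, and it assumes one workgroup per (BM, BN) output tile per split and one workgroup per CU in the first wave.

```python
def launch_occupancy(m, n, bm=64, bn=64, num_splits=1, num_cus=304):
    """Return (workgroup count, fraction of CUs covered in the first wave).

    Hypothetical helper for illustration; num_cus defaults to the CU
    count quoted in the text above.
    """
    def ceil_div(a, b):
        return (a + b - 1) // b

    # One workgroup per output tile, replicated once per K split.
    wgs = ceil_div(m, bm) * ceil_div(n, bn) * num_splits
    return wgs, min(1.0, wgs / num_cus)


# Base kernel at full K: M=64, N=8192, BM=BN=64.
wgs, frac = launch_occupancy(m=64, n=8192)
print(wgs, round(frac * 100))  # 128 42

# Two K splits double the workgroup count.
wgs2, frac2 = launch_occupancy(m=64, n=8192, num_splits=2)
print(wgs2, round(frac2 * 100))  # 256 84
```

Under these assumptions, each additional split adds another 128 workgroups, which is the motivation for extending the grid along z.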
Splits the K dimension into num_splits chunks and launches all
splits in ONE kernel invocation by extending the launch grid with
grid_dim.z = num_splits. Each WG decodes its split from block_idx.z
and writes to its own slot in a stacked-M workspace of shape
(num_splits * M, N) row-major float32. A subsequent reduce kernel
sums the num_splits partials and casts to the final output dtype.
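The data flow described above (per-split partials written to a stacked-M workspace, then a reduce pass) can be modelled on the host in a few lines. This is a minimal reference sketch in plain Python, not the kernel itself: `split_k_matmul` here is a hypothetical function, it assumes B is laid out as (K, N) and that K divides evenly by `num_splits`, and it omits the float32 accumulation and the final dtype cast.

```python
def split_k_matmul(a, b, num_splits):
    """Reference model of split-K: a is M x K, b is K x N (lists of lists)."""
    m, k, n = len(a), len(b), len(b[0])
    assert k % num_splits == 0, "sketch assumes K divides evenly"
    chunk = k // num_splits

    # Stacked-M workspace of shape (num_splits * M, N), row-major.
    workspace = [[0.0] * n for _ in range(num_splits * m)]

    # "Kernel" stage: z plays the role of block_idx.z. Each split
    # covers K range [z * chunk, (z + 1) * chunk) and writes only
    # to its own M-row slab of the workspace.
    for z in range(num_splits):
        k0 = z * chunk
        for i in range(m):
            row = workspace[z * m + i]
            for j in range(n):
                row[j] = sum(a[i][kk] * b[kk][j] for kk in range(k0, k0 + chunk))

    # Reduce stage: sum the num_splits partials for each output element.
    return [
        [sum(workspace[z * m + i][j] for z in range(num_splits)) for j in range(n)]
        for i in range(m)
    ]


a = [[1, 2, 3, 4], [5, 6, 7, 8]]
b = [[1, 0], [0, 1], [1, 0], [0, 1]]
print(split_k_matmul(a, b, num_splits=2))  # [[4, 6], [12, 14]]
```

Because each split owns a disjoint slab of the workspace, the splits need no synchronization with each other; only the reduce pass depends on all of them having finished.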
Single-launch matters: a multi-stream approach (deleted Apr 2026) lost 5-60% across decode/prefill because record_event/wait_for sync overhead exceeded the per-kernel runtime. A single launch lets the GPU schedule all splits concurrently with no host-side sync between them, paying only the one reduce launch on top.
Structs
- SplitKWorkspace: Pre-allocated scratch for repeated split-K launches.
Functions
- amd_4wave_split_k_matmul: Launches the single-launch split-K 4-wave matmul on the device.