
Mojo module

amd_4wave_split_k_matmul

Single-launch split-K wrapper for the 4-wave FP8 matmul.

Targets the small-M decode regime, where the base 4-wave kernel doesn't saturate the 304 CUs of an MI355X at its natural launch geometry. With M ≤ 64, N = K = 8192, and BM = BN = 64, the grid is 1 × 128 = 128 WGs at full K, which fills only ~42% of the CUs in a single wave, so the per-WG K-loop dominates wall-clock time.
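
The occupancy arithmetic above can be reproduced with a short calculation. This is a minimal sketch, not part of the module: the helper name `launch_geometry` is hypothetical, and it assumes one resident WG per CU per wave, which matches the "1 wave" framing above but ignores real occupancy limits. It shows how stacking splits along grid_dim.z multiplies the WG count while shortening each WG's K-loop.

```python
import math

def launch_geometry(M, N, K, BM, BN, num_splits=1, num_cus=304):
    """Hypothetical helper: WG count, first-wave CU fill, wave count, and
    per-WG K-loop length for a ceil(M/BM) x ceil(N/BN) output-tile grid
    repeated num_splits times along grid_dim.z.

    Assumes one resident WG per CU per wave (simplification).
    """
    num_wgs = math.ceil(M / BM) * math.ceil(N / BN) * num_splits
    fill = min(num_wgs, num_cus) / num_cus   # fraction of CUs busy in the first wave
    waves = math.ceil(num_wgs / num_cus)     # waves needed under the 1-WG-per-CU model
    k_per_wg = math.ceil(K / num_splits)     # K elements each WG iterates over
    return num_wgs, fill, waves, k_per_wg

for splits in (1, 2):
    wgs, fill, waves, k_len = launch_geometry(
        M=64, N=8192, K=8192, BM=64, BN=64, num_splits=splits)
    print(f"num_splits={splits}: {wgs} WGs, {fill:.0%} of CUs, "
          f"{waves} wave(s), K-loop length {k_len}")
# num_splits=1: 128 WGs, 42% of CUs, 1 wave(s), K-loop length 8192
# num_splits=2: 256 WGs, 84% of CUs, 1 wave(s), K-loop length 4096
```

Under this simplified model, three splits (384 WGs) would spill into a second wave on 304 CUs; the split count the module actually picks is not stated here.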

Splits the K dimension into num_splits chunks and launches all splits in ONE kernel invocation by extending the launch grid with grid_dim.z = num_splits. Each WG decodes its split from block_idx.z and writes to its own slot in a stacked-M workspace of shape (num_splits * M, N) row-major float32. A subsequent reduce kernel sums the num_splits partials and casts to the final output dtype.
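The data flow described above can be modeled on the host as a reference. This is a hedged sketch in plain Python, not the kernel or its API: `split_k_matmul` is a hypothetical name, the loop over `z` stands in for block_idx.z, K is assumed to divide evenly by num_splits, and FP8 inputs plus the final cast to the output dtype are omitted (Python floats stand in for the float32 workspace).

```python
def split_k_matmul(A, B, num_splits):
    """Reference model of split-K with a stacked-M workspace.

    A is M x K, B is K x N (lists of lists). Returns (out, workspace), where
    workspace has num_splits * M rows and split z owns rows [z*M, (z+1)*M).
    """
    M, K, N = len(A), len(A[0]), len(B[0])
    assert K % num_splits == 0, "sketch assumes K divides evenly"
    k_chunk = K // num_splits

    # Stacked-M workspace of shape (num_splits * M, N), one slot per split.
    workspace = [[0.0] * N for _ in range(num_splits * M)]

    # Phase 1: each z plays the role of one block_idx.z slice of the grid.
    for z in range(num_splits):
        k_begin, k_end = z * k_chunk, (z + 1) * k_chunk
        for m in range(M):
            row = z * M + m  # this split's private slot for output row m
            for n in range(N):
                acc = 0.0
                for k in range(k_begin, k_end):
                    acc += A[m][k] * B[k][n]
                workspace[row][n] = acc  # plain store, no cross-split aliasing

    # Phase 2: the reduce step sums the num_splits partials per output element.
    out = [[0.0] * N for _ in range(M)]
    for z in range(num_splits):
        for m in range(M):
            for n in range(N):
                out[m][n] += workspace[z * M + m][n]
    return out, workspace

A = [[1, 2, 3, 4], [5, 6, 7, 8]]          # M=2, K=4
B = [[1, 0], [0, 1], [1, 1], [2, -1]]     # K=4, N=2
out, workspace = split_k_matmul(A, B, num_splits=2)
print(out)        # [[12.0, 1.0], [28.0, 5.0]]
print(workspace)  # [[1.0, 2.0], [5.0, 6.0], [11.0, -1.0], [23.0, -1.0]]
```

Because every split writes only its own M-row slot, the partial-sum stores never collide, so no atomic accumulation is needed; the cost is the extra workspace and the separate reduce pass.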

Single-launch matters: a multi-stream approach (deleted Apr 2026) lost 5–60% across decode/prefill due to record_event/wait_for sync overhead exceeding the per-kernel runtime. Same-launch lets the GPU schedule all splits concurrently with no host-side sync between them, paying only the one reduce launch on top.

Structs

Functions