IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo module

amd_4wave_split_k_matmul

Single-launch split-K wrapper for the 4-wave FP8 matmul.

Targets the small-M decode regime where the base 4-wave kernel doesn't saturate 304 MI355X CUs at the natural launch geometry (M ≤ 64, N=K=8192, BM=BN=64 → 128 WGs at full K, ~42% of CUs in 1 wave; per-WG K-loop dominates wall-clock).
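The occupancy figures above can be checked with a few lines of arithmetic. This is a minimal sketch in Python, not part of the module: `workgroup_count` and `min_splits_to_fill` are hypothetical helper names, and the "smallest split count that covers every CU" rule is only an illustrative heuristic, not how the kernel chooses `num_splits`.

```python
import math

NUM_CUS = 304  # MI355X compute-unit count quoted in the text


def workgroup_count(m, n, bm, bn, num_splits=1):
    """Workgroups launched: one per (BM x BN) output tile, per K split."""
    return math.ceil(m / bm) * math.ceil(n / bn) * num_splits


def min_splits_to_fill(base_wgs, num_cus=NUM_CUS):
    """Hypothetical heuristic: smallest split count giving >= one WG per CU."""
    return math.ceil(num_cus / base_wgs)


# Geometry from the text: M <= 64, N = K = 8192, BM = BN = 64.
base_wgs = workgroup_count(m=64, n=8192, bm=64, bn=64)
base_fraction = base_wgs / NUM_CUS
print(base_wgs, f"{base_fraction:.0%}")  # 128 42%

# With grid_dim.z = num_splits, the WG count scales linearly with the splits.
splits = min_splits_to_fill(base_wgs)
split_wgs = workgroup_count(m=64, n=8192, bm=64, bn=64, num_splits=splits)
print(splits, split_wgs)  # 3 384
```

Under these assumptions, three splits would already put at least one workgroup on every CU while shortening each workgroup's K-loop to a third of its original length.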

Splits the K dimension into num_splits chunks and launches all splits in ONE kernel invocation by extending the launch grid with grid_dim.z = num_splits. Each WG decodes its split from block_idx.z and writes to its own slot in a stacked-M workspace of shape (num_splits * M, N) row-major float32. A subsequent reduce kernel sums the num_splits partials and casts to the final output dtype.
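The data flow described above can be emulated on the host. This is a minimal pure-Python sketch, not the kernel's implementation: the function names are hypothetical, B is assumed to be laid out as (K, N), K is assumed to divide evenly by `num_splits`, and the loop over `split` stands in for workgroups reading `block_idx.z`. The final cast to the output dtype is omitted.

```python
def split_k_partials(a, b, num_splits):
    """Emulate the split-K launch: each split writes its own M-row slot."""
    m, k, n = len(a), len(a[0]), len(b[0])
    assert k % num_splits == 0, "sketch assumes K divides evenly"
    k_chunk = k // num_splits
    # Stacked-M workspace of shape (num_splits * M, N), row-major.
    workspace = [[0.0] * n for _ in range(num_splits * m)]
    for split in range(num_splits):  # stands in for block_idx.z
        k_start = split * k_chunk
        for i in range(m):
            for j in range(n):
                acc = 0.0
                for kk in range(k_start, k_start + k_chunk):
                    acc += a[i][kk] * b[kk][j]
                workspace[split * m + i][j] = acc  # this split's own slot
    return workspace


def reduce_partials(workspace, m, num_splits):
    """Emulate the reduce kernel: sum the num_splits partials per element."""
    n = len(workspace[0])
    return [
        [sum(workspace[s * m + i][j] for s in range(num_splits)) for j in range(n)]
        for i in range(m)
    ]


def matmul_reference(a, b):
    """Plain full-K matmul used only to check the sketch."""
    return [
        [float(sum(a[i][kk] * b[kk][j] for kk in range(len(b)))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


a = [[1, 2, 3, 4], [5, 6, 7, 8]]                 # M=2, K=4
b = [[1, 0, 2], [0, 1, 3], [4, 5, 6], [7, 8, 9]]  # K=4, N=3
workspace = split_k_partials(a, b, num_splits=2)
result = reduce_partials(workspace, m=2, num_splits=2)
print(len(workspace), result == matmul_reference(a, b))  # 4 True
```

Because every split writes to a disjoint row range of the workspace, no atomics or inter-workgroup synchronization are needed during the matmul itself; the only cross-split dependency is the reduce step that follows.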

Single-launch matters: a multi-stream approach (deleted Apr 2026) lost 5–60% across decode/prefill due to record_event/wait_for sync overhead exceeding the per-kernel runtime. Same-launch lets the GPU schedule all splits concurrently with no host-side sync between them, paying only the one reduce launch on top.

Structs

Functions