IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo module

amd_4wave_schedule

Inline 4-wave schedule for AMD GPU matmul / implicit-GEMM conv kernels.

Per half, the schedule runs four mini-iters of (LOAD, FRAG, MMA), with cross-stage MMA fragment rotation in mini-iters 3-4: those frag-loads read the cross SMEM stage to pre-load the next half's leading-quadrant fragments while this half's last MMAs are still computing. This register-rotation trick hides LDS-load latency between halves, saving one frag-load's worth of LDS latency on each half's first MMA.
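The register rotation can be checked with a small Python model. This is a sketch under stated assumptions, not the kernel's code. It keeps only the frag-load and MMA parts of each mini-iter. It reads the MiniIterSpec columns listed under FOUR_WAVE_MINI_ITERS as (frag operand, frag sub, cross-stage flag, MMA m, MMA n) and assumes MMA(m, n) reads A_quad[m] and B_quad[n]. It also assumes the usual ping-pong invariant: during half g, the own stage holds K-tile g and the cross stage holds K-tile g + 1. The names `Mini` and `simulate` are hypothetical.

```python
from dataclasses import dataclass, replace

A, B = 0, 1  # operand ids; register quadrants are keyed as (operand, sub)


@dataclass(frozen=True)
class Mini:
    """Hypothetical view of one mini-iter: only its frag-load and MMA."""
    frag_operand: int
    frag_sub: int
    cross_stage: bool
    mma_m: int
    mma_n: int


# Frag/MMA columns of FOUR_WAVE_MINI_ITERS under the assumed field order.
MINIS = [
    Mini(B, 1, False, 0, 0),
    Mini(A, 1, False, 0, 1),
    Mini(A, 0, True, 1, 0),  # cross-stage: pre-loads next half's A_quad[0]
    Mini(B, 0, True, 1, 1),  # cross-stage: pre-loads next half's B_quad[0]
]


def simulate(minis, num_halves):
    """Track which K-tile each register quadrant holds; fail if any MMA reads stale data."""
    # Bootstrap state: A_quad[0] / B_quad[0] already hold K-tile 0 (same-stage, sub=0).
    regs = {(A, 0): 0, (B, 0): 0}
    log = []
    for g in range(num_halves):  # g = global half index = K-tile consumed by this half
        stage = g % 2
        for mini in minis:
            # Assumed ping-pong invariant: own stage holds tile g, cross stage holds g + 1.
            regs[(mini.frag_operand, mini.frag_sub)] = g + 1 if mini.cross_stage else g
            a_tile = regs.get((A, mini.mma_m))
            b_tile = regs.get((B, mini.mma_n))
            if a_tile != g or b_tile != g:
                raise RuntimeError(
                    f"half {g}: MMA({mini.mma_m},{mini.mma_n}) read A={a_tile}, B={b_tile}"
                )
            log.append((g, stage, mini.mma_m, mini.mma_n))
    return log


log = simulate(MINIS, 4)
print(f"{len(log)} MMAs read correct data")

# Same schedule, but minis 3-4 read their own stage instead of the cross stage.
same_stage = [replace(m, cross_stage=False) for m in MINIS]
try:
    simulate(same_stage, 2)
except RuntimeError as err:
    print("without rotation:", err)
```

In this model, switching minis 3-4 to same-stage frags makes half 1's first MMA read K-tile 0 data. The leading quadrants then have to be re-loaded at the start of the half, which is the LDS latency the rotation is meant to hide.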

Three load-bearing structural choices distinguish this schedule from the default ping-pong-shaped matmul schedule:

  1. Mini-iter declaration order: the body is (LOAD, FRAG, MMA) × 4 per K-partition (12 ops × 2 partitions = 24). After _construct_mma_blocks packs it into MMA-centered blocks, each block holds exactly 1 frag + 1 global + 1 MMA, matching _run_iter's mini-iter shape.

  2. Cross-stage MMA fragment rotation: minis 3-4 emit MMA_LOAD_A[stage=os, sub=0] and MMA_LOAD_B[stage=os, sub=0], reading from the cross stage. Same-partition MMAs no longer touch A_quad[0]/B_quad[0] after mini-iter 2, so the cross-stage frag harmlessly overwrites them with data the next partition's leading MMA will consume.

  3. SchedulingStrategy.IDENTITY and no double_buffer_reorder: the body order is the final order. The framework's mma_block_interleave_list matches frags to MMAs by subtile while ignoring stage, so running it on a cross-stage body would place a wrong-stage frag before the same-partition MMA. Bypassing the reorder preserves the per-mini-iter grouping.
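Choices 1 and 3 can be sketched in Python under a few assumptions. The per-mini-iter tuples transcribe FOUR_WAVE_MINI_ITERS with a guessed field order, and partition p is taken to consume SMEM stage p. `pack_mma_blocks` is a hypothetical stand-in for `_construct_mma_blocks` that simply closes a block at each MMA. The sketch builds the 24-op body, packs it into eight 1-global + 1-frag + 1-MMA blocks, and checks that every sub=0 frag reads the cross stage. That last property is why a matcher that ignores stage could pair those frags with the wrong MMA.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Op:
    kind: str         # "global", "frag", or "mma"
    name: str         # e.g. "LOAD_A", "MMA_LOAD_B", "MMA"
    partition: int    # K-partition (half) that issues the op; assumed to consume stage == partition
    stage: int = -1   # SMEM stage a frag-load reads; -1 if not applicable
    sub: int = -1     # subtile index for loads; -1 if not applicable
    quad: tuple = ()  # (m, n) quadrant for MMAs


# (global op, global sub, frag op, frag sub, cross_stage, MMA quadrant) per mini-iter.
# Transcribed from FOUR_WAVE_MINI_ITERS; the field order is an assumption.
MINI_ITERS = [
    ("LOAD_A", 0, "MMA_LOAD_B", 1, False, (0, 0)),
    ("LOAD_B", 0, "MMA_LOAD_A", 1, False, (0, 1)),
    ("LOAD_B", 1, "MMA_LOAD_A", 0, True, (1, 0)),
    ("LOAD_A", 1, "MMA_LOAD_B", 0, True, (1, 1)),
]


def build_body():
    """Emit (LOAD, FRAG, MMA) x 4 for each of the two K-partitions."""
    body = []
    for p in (0, 1):
        other = 1 - p  # "os": the other SMEM stage
        for g_name, g_sub, f_name, f_sub, is_cross, quad in MINI_ITERS:
            body.append(Op("global", g_name, p, sub=g_sub))
            body.append(Op("frag", f_name, p, stage=other if is_cross else p, sub=f_sub))
            body.append(Op("mma", "MMA", p, quad=quad))
    return body


def pack_mma_blocks(body):
    """Hypothetical MMA-centered packing: each block ends at (and includes) one MMA."""
    blocks, current = [], []
    for op in body:
        current.append(op)
        if op.kind == "mma":
            blocks.append(current)
            current = []
    if current:
        raise ValueError("ops left over after the last MMA")
    return blocks


body = build_body()
blocks = pack_mma_blocks(body)
schedule = [op for block in blocks for op in block]  # IDENTITY: no reorder applied
print(len(body), "ops in", len(blocks), "blocks; schedule == body:", schedule == body)
```

With the identity strategy, flattening the blocks gives back the body exactly, so the per-mini-iter grouping survives into the final schedule.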

The framework's default double_buffer_edge_rules include Phase 1 (FRAGMENT_LOAD → COMPUTE, same_half=True, use_config_match=True). Cross-stage rotation, however, is a cross-half register flow: half 0's mini-3-4 frags feed half 1's MMA(0,*), and half 1's feed the next iteration's half-0 MMA(0,*). derive_edges therefore overrides the default to append the cross-half rules, so that wait derivation knows to drain frag lgkm at the half boundary.
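The same-half versus cross-half split can be made concrete by tracking, for one steady-state iteration, which frag-load last wrote each register quadrant that an MMA reads. This is a sketch with several assumptions: MMA(m, n) reads A_quad[m] and B_quad[n], each mini-iter issues its frag-load before its MMA, and the previous iteration's half 1 (tagged half -1) seeds the registers. `derive_frag_edges` is a hypothetical helper, not the framework's derive_edges. Every edge it labels cross_half is one that a same_half=True rule alone would not produce.

```python
# Per mini-iter: (register quadrant the frag-load writes, MMA quadrant (m, n)).
# Minis 3-4 (indices 2-3) are the cross-stage frags. Field meaning is an assumption.
MINIS = [
    (("B", 1), (0, 0)),
    (("A", 1), (0, 1)),
    (("A", 0), (1, 0)),
    (("B", 0), (1, 1)),
]


def derive_frag_edges():
    """Return (producer, consumer, kind) frag->MMA edges for one iteration.

    producer and consumer are (half, mini_index); half -1 is the previous
    iteration's half 1, so edges from it model the loop wrap-around.
    """
    # Seed with the writers from the previous iteration's half 1.
    last_writer = {reg: (-1, i) for i, (reg, _) in enumerate(MINIS)}
    edges = []
    for half in (0, 1):
        for i, (reg, (m, n)) in enumerate(MINIS):
            last_writer[reg] = (half, i)  # frag-load issues before this mini's MMA
            for operand in (("A", m), ("B", n)):
                producer = last_writer[operand]
                kind = "same_half" if producer[0] == half else "cross_half"
                edges.append((producer, (half, i), kind))
    return edges


for producer, consumer, kind in derive_frag_edges():
    if kind == "cross_half":
        print(f"frag {producer} -> MMA {consumer}")
```

In this model, every cross-half producer is a mini-3 or mini-4 frag, which matches the rotation described above.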

The framework's auto-prologue (derive_prologue_from_program) infers the initial frag-loads from the body's same-stage b.fan[2] ops. With this body's cross-stage sub=0 frags, it would read the cross stage at k_base=0 (that is, k=BK data) instead of the same stage (k=0 data), which is what the first main iteration's MMA[0,0] and MMA[0,1] need in A_quad[0] and B_quad[0]. To fix this, the calling kernel inserts an explicit bootstrap pair after the framework prologue that overwrites those quadrants with same-stage sub=0 data.
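The k_base arithmetic behind the bootstrap pair can be sketched as follows. The sketch assumes a ping-pong layout in which SMEM stage s holds the K-tile starting at k_base + s * BK. BK = 64 and all helper names are hypothetical; only the choice of stage matters.

```python
BK = 64  # hypothetical K-tile depth; any positive value shows the same effect


def smem_k(stage, k_base, bk=BK):
    """Assumed ping-pong layout: stage s holds the K-tile starting at k_base + s * bk."""
    return k_base + stage * bk


def auto_prologue_k(k_base=0):
    """What an inferred prologue would load: the body's sub=0 frags, which for
    partition 0 read the cross stage (os = 1)."""
    return {"A_quad[0]": smem_k(1, k_base), "B_quad[0]": smem_k(1, k_base)}


def bootstrap_k(k_base=0):
    """Explicit bootstrap pair: same stage (0), sub=0."""
    return {"A_quad[0]": smem_k(0, k_base), "B_quad[0]": smem_k(0, k_base)}


def first_mma_ok(quads, k_base=0):
    """MMA[0,0] and MMA[0,1] of the first main iteration need k = k_base data here."""
    return all(k == k_base for k in quads.values())


regs = auto_prologue_k()
print("after framework prologue:", regs, "ok:", first_mma_ok(regs))
regs.update(bootstrap_k())  # bootstrap runs after the prologue and overwrites both quadrants
print("after bootstrap pair:", regs, "ok:", first_mma_ok(regs))
```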

comptime values

COMPUTE

comptime COMPUTE = FourWaveOps.COMPUTE.value

Integer tag for the reserved generic compute op.

FOUR_WAVE_MINI_ITERS

comptime FOUR_WAVE_MINI_ITERS = List(MiniIterSpec(LOAD_A, Int(0), Int(0), MMA_LOAD_B, Int(1), Int(1), False, Int(0), Int(0)), MiniIterSpec(LOAD_B, Int(1), Int(0), MMA_LOAD_A, Int(0), Int(1), False, Int(0), Int(1)), MiniIterSpec(LOAD_B, Int(1), Int(1), MMA_LOAD_A, Int(0), Int(0), True, Int(1), Int(0)), MiniIterSpec(LOAD_A, Int(0), Int(1), MMA_LOAD_B, Int(1), Int(0), True, Int(1), Int(1)), __list_literal__=NoneType(None))

The four mini-iters of the 4-wave schedule, per K-partition.

Same shape across both partitions; only the SMEM stage flips. The frag-loads in minis 3 and 4 read from the cross stage to pre-load the next partition's leading quadrants (A_quad[0] / B_quad[0]) while this partition's last MMAs are still issuing.
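To make FOUR_WAVE_MINI_ITERS easier to read, the sketch below rebuilds the list in Python and prints one row per mini-iter for each partition. The MiniIterSpec field names (global op and sub, frag op and sub, cross-stage flag, MMA quadrant) are guesses from the positional values and the prose above. The op tags are string placeholders rather than the real FourWaveOps values.

```python
from dataclasses import dataclass

# Placeholder tags; the real values come from FourWaveOps.
LOAD_A, LOAD_B, MMA_LOAD_A, MMA_LOAD_B = "LOAD_A", "LOAD_B", "MMA_LOAD_A", "MMA_LOAD_B"


@dataclass(frozen=True)
class MiniIterSpec:
    # Assumed meaning of the nine positional fields.
    global_op: str
    global_operand: int  # 0 = A, 1 = B
    global_sub: int
    frag_op: str
    frag_operand: int
    frag_sub: int
    cross_stage: bool
    mma_m: int
    mma_n: int


FOUR_WAVE_MINI_ITERS = [
    MiniIterSpec(LOAD_A, 0, 0, MMA_LOAD_B, 1, 1, False, 0, 0),
    MiniIterSpec(LOAD_B, 1, 0, MMA_LOAD_A, 0, 1, False, 0, 1),
    MiniIterSpec(LOAD_B, 1, 1, MMA_LOAD_A, 0, 0, True, 1, 0),
    MiniIterSpec(LOAD_A, 0, 1, MMA_LOAD_B, 1, 0, True, 1, 1),
]


def describe(spec, stage):
    """Render one mini-iter for the partition that consumes SMEM stage `stage`."""
    frag_stage = 1 - stage if spec.cross_stage else stage
    return (
        f"{spec.global_op}[sub={spec.global_sub}] | "
        f"{spec.frag_op}[stage={frag_stage}, sub={spec.frag_sub}] | "
        f"MMA({spec.mma_m},{spec.mma_n})"
    )


for stage in (0, 1):
    for i, spec in enumerate(FOUR_WAVE_MINI_ITERS, start=1):
        print(f"partition {stage}, mini {i}: {describe(spec, stage)}")
```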

LOAD_A

comptime LOAD_A = FourWaveOps.LOAD_A.value

Integer tag for A DRAM->LDS prefetch.

LOAD_B

comptime LOAD_B = FourWaveOps.LOAD_B.value

Integer tag for B DRAM->LDS prefetch.

MMA

comptime MMA = FourWaveOps.MMA.value

Integer tag for the MFMA quadrant compute op.

MMA_LOAD_A

comptime MMA_LOAD_A = FourWaveOps.MMA_LOAD_A.value

Integer tag for A LDS->register frag-load.

MMA_LOAD_B

comptime MMA_LOAD_B = FourWaveOps.MMA_LOAD_B.value

Integer tag for B LDS->register frag-load.

Structs

Functions