Mojo module

amd_4wave_schedule

Inline 4-wave schedule for AMD GPU FP8 matmul kernels.

This schedule mirrors the hand-written _run_iter body in amd_4wave_matmul.run op-for-op. Each half runs four mini-iters of (LOAD, FRAG, MMA), with cross-stage MMA fragment rotation in mini-iters 3-4: those frag-loads read the cross SMEM stage to pre-load the next half's leading-quadrant fragments while this half's last MMAs are still computing. This is the same register-rotation trick the hand-written body uses to hide LDS-load latency between halves; it saves one frag-load's worth of LDS latency on each half's first MMA.

Three load-bearing structural choices vs the default ping-pong-shaped matmul schedule:

  1. Mini-iter declaration order: the body is (LOAD, FRAG, MMA) * 4 per K-partition (12 ops × 2 partitions = 24). After _construct_mma_blocks packs the body into MMA-centered blocks, each block holds exactly 1 frag + 1 global + 1 MMA, matching _run_iter's mini-iter shape.
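The declaration-order invariant above can be modeled with a small sketch. This is an illustrative Python model, not the framework's API: the Op type, build_body, and pack_mma_blocks are invented names standing in for the body construction and _construct_mma_blocks.

```python
# Hedged sketch (illustrative model, not the Mojo framework's types):
# build the 4-wave body as a flat op list and pack it into MMA-centered
# blocks. Because the declaration order is (LOAD, FRAG, MMA) * 4 per
# K-partition, every packed block ends up with exactly 1 LOAD + 1 FRAG + 1 MMA.
from dataclasses import dataclass


@dataclass(frozen=True)
class Op:
    kind: str        # "LOAD" (DRAM->LDS), "FRAG" (LDS->reg), or "MMA"
    partition: int   # K-partition (half): 0 or 1
    mini: int        # mini-iter index within the half: 0..3


def build_body() -> list[Op]:
    body = []
    for part in range(2):          # two K-partitions per main iter
        for mini in range(4):      # four mini-iters per partition
            for kind in ("LOAD", "FRAG", "MMA"):
                body.append(Op(kind, part, mini))
    return body


def pack_mma_blocks(body: list[Op]) -> list[list[Op]]:
    """Greedy packing: each MMA closes a block."""
    blocks, cur = [], []
    for op in body:
        cur.append(op)
        if op.kind == "MMA":
            blocks.append(cur)
            cur = []
    return blocks


body = build_body()
blocks = pack_mma_blocks(body)
assert len(body) == 24            # 12 ops x 2 partitions
assert len(blocks) == 8           # 4 MMA-centered blocks per partition
assert all([op.kind for op in blk] == ["LOAD", "FRAG", "MMA"]
           for blk in blocks)
```

The model only captures the op ordering; buffer indices and stage selection are left out here.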

  2. Cross-stage MMA fragment rotation: minis 3-4 emit MMA_LOAD_A[stage=os, sub=0] and MMA_LOAD_B[stage=os, sub=0], reading from the cross stage. Same-partition MMAs no longer touch A_quad[0]/B_quad[0] after mini-iter 2, so the cross-stage frag harmlessly overwrites them with the data the next partition's leading MMA will consume.
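The stage each mini-iter's frag-load reads can be written down directly. A minimal sketch, assuming two SMEM stages and using 1-based mini-iter numbering as in the text; frag_load_stage is an invented name:

```python
# Hedged sketch (illustrative, not the framework API): which SMEM stage a
# mini-iter's frag-load reads. s is this partition's stage, os = 1 - s is
# the cross stage. Minis 1-2 read same-stage data for this half's MMAs;
# minis 3-4 read the cross stage, pre-loading the next half's sub=0 quadrants.
def frag_load_stage(s: int, mini: int) -> int:
    os = 1 - s                      # cross stage
    return s if mini <= 2 else os   # rotation kicks in at mini-iter 3


# Half on stage 0: frags read 0, 0, then cross to 1, 1.
assert [frag_load_stage(0, m) for m in (1, 2, 3, 4)] == [0, 0, 1, 1]
# Half on stage 1: the mirror image.
assert [frag_load_stage(1, m) for m in (1, 2, 3, 4)] == [1, 1, 0, 0]
```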

  3. SchedulingStrategy.IDENTITY and no double_buffer_reorder: the body order is the final order. The framework's mma_block_interleave_list matches frags to MMAs by subtile while ignoring stage; running it on a cross-stage body would place a wrong-stage frag before the same-partition MMA. Bypassing the reorder preserves the per-mini-iter grouping.

The framework's default double_buffer_edge_rules include Phase 1 (FRAGMENT_LOAD -> COMPUTE, same_half=True, use_config_match=True). Cross-stage rotation is a cross-half register flow: half 0's mini-3/4 frags feed half 1's MMA(0,), and half 1's feed the next iter's half-0 MMA(0,). derive_edges overrides the default to append cross-half rules so wait derivation knows to drain frag lgkm at the half boundary.
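The same_half mismatch can be illustrated with a toy rule model. This is a sketch under stated assumptions, not the framework's rule types; the dict shape and edges_for_flow are invented for illustration:

```python
# Hedged sketch (illustrative rule model, not double_buffer_edge_rules'
# real representation): a rule restricted to same_half=True cannot match a
# cross-half register flow, so no wait edge would be derived for the
# half-boundary frag drain until a same_half=False rule is appended.
DEFAULT_RULES = [
    {"src": "FRAGMENT_LOAD", "dst": "COMPUTE", "same_half": True},
]


def edges_for_flow(rules, src, dst, same_half):
    """Return the rules that would fire for a (src -> dst) flow."""
    return [r for r in rules
            if r["src"] == src and r["dst"] == dst
            and r["same_half"] == same_half]


# Half 0's mini-3/4 frags feed half 1's leading MMA: a cross-half flow.
# The default Phase 1 rule does not see it:
assert edges_for_flow(DEFAULT_RULES, "FRAGMENT_LOAD", "COMPUTE", False) == []

# Appending a cross-half rule (as derive_edges does, per the text) makes
# the wait at the half boundary derivable:
patched = DEFAULT_RULES + [
    {"src": "FRAGMENT_LOAD", "dst": "COMPUTE", "same_half": False},
]
assert len(edges_for_flow(patched, "FRAGMENT_LOAD", "COMPUTE", False)) == 1
```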

The framework's auto-prologue (derive_prologue_from_program) infers the initial frag-loads from the body's same-stage b.fan[2] ops; with this body's cross-stage sub=0 frags it would read the cross stage at k_base=0 (= k=BK data) instead of the same stage (= k=0 data) that the first main iter's MMA[0,0]+MMA[0,1] (A_quad[0]+B_quad[0]) needs. The calling kernel therefore inserts an explicit bootstrap pair after the framework prologue, overwriting those quadrants with same-stage sub=0 data.
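The stage mix-up can be shown arithmetically. A minimal sketch, assuming two SMEM stages and a stage = k_iter % 2 mapping; stage_for_k and body_frag_stage are illustrative names, not framework functions:

```python
# Hedged sketch: why auto-prologue inference picks the wrong stage here.
NUM_STAGES = 2


def stage_for_k(k_iter: int) -> int:
    return k_iter % NUM_STAGES      # double-buffered SMEM stage (assumed)


def body_frag_stage(s: int) -> int:
    return 1 - s                    # body's sub=0 frags are cross-stage


# Auto-prologue lifts the body's stage expression and evaluates it at
# k_base=0, so it would read the cross stage (k=BK data):
inferred = body_frag_stage(stage_for_k(0))
# The first main iter's MMA[0,0]+MMA[0,1] needs same-stage k=0 data:
needed = stage_for_k(0)
assert (inferred, needed) == (1, 0)
# Hence the explicit bootstrap pair: same-stage sub=0 frag-loads inserted
# after the framework prologue to overwrite A_quad[0]/B_quad[0].
```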

comptime values

COMPUTE

comptime COMPUTE = FourWaveOps.COMPUTE.value

Integer tag for the reserved generic compute op.

FOUR_WAVE_MINI_ITERS

comptime FOUR_WAVE_MINI_ITERS = List(MiniIterSpec(LOAD_A, 0, 0, MMA_LOAD_B, 1, 1, False, 0, 0), MiniIterSpec(LOAD_B, 1, 0, MMA_LOAD_A, 0, 1, False, 0, 1), MiniIterSpec(LOAD_B, 1, 1, MMA_LOAD_A, 0, 0, True, 1, 0), MiniIterSpec(LOAD_A, 0, 1, MMA_LOAD_B, 1, 0, True, 1, 1), __list_literal__=NoneType(None))

4-wave's 4 mini-iters per K-partition.

Same shape across both partitions: only the SMEM stage flips. Mini-3/4's frag-loads read from the cross stage to pre-load the next partition's leading quadrants (A_quad[0] / B_quad[0]) while this partition's last MMAs still issue.
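The documented pattern can be read back out of the rendered literal above. The field positions in this sketch are an inference from that literal and are an assumption: (load_op, load_buf, load_sub, frag_op, frag_buf, frag_sub, cross_stage, mma_i, mma_j). Only the cross-stage flag and frag subtile are checked:

```python
# Hedged sketch: the four MiniIterSpec entries transcribed as tuples, with
# field meanings assumed as noted in the lead-in (not confirmed by the source).
SPECS = [
    ("LOAD_A", 0, 0, "MMA_LOAD_B", 1, 1, False, 0, 0),
    ("LOAD_B", 1, 0, "MMA_LOAD_A", 0, 1, False, 0, 1),
    ("LOAD_B", 1, 1, "MMA_LOAD_A", 0, 0, True,  1, 0),
    ("LOAD_A", 0, 1, "MMA_LOAD_B", 1, 0, True,  1, 1),
]

cross_flags = [s[6] for s in SPECS]
frag_subs = [s[5] for s in SPECS]

# Minis 3-4 carry the cross-stage flag, and their frag-loads target sub=0,
# i.e. the next partition's leading quadrants A_quad[0]/B_quad[0]:
assert cross_flags == [False, False, True, True]
assert frag_subs[2:] == [0, 0]
```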

LOAD_A

comptime LOAD_A = FourWaveOps.LOAD_A.value

Integer tag for A DRAM->LDS prefetch.

LOAD_B

comptime LOAD_B = FourWaveOps.LOAD_B.value

Integer tag for B DRAM->LDS prefetch.

MMA

comptime MMA = FourWaveOps.MMA.value

Integer tag for the MFMA quadrant compute op.

MMA_LOAD_A

comptime MMA_LOAD_A = FourWaveOps.MMA_LOAD_A.value

Integer tag for A LDS->register frag-load.

MMA_LOAD_B

comptime MMA_LOAD_B = FourWaveOps.MMA_LOAD_B.value

Integer tag for B LDS->register frag-load.

Structs

Functions