Mojo module
amd_4wave_schedule
Inline 4-wave schedule for AMD GPU FP8 matmul kernels.
Mirrors the hand-written _run_iter body in amd_4wave_matmul.run
op-for-op. Per half, four mini-iters of (LOAD, FRAG, MMA) with
cross-stage MMA fragment rotation in mini-iters 3-4: those frag-loads
read the cross SMEM stage to pre-load the next half's leading-quadrant
fragments while this half's last MMAs are still computing. The same
register-rotation trick the hand-written body uses to hide LDS-load
latency between halves (saves one frag-load worth of LDS latency on each
half's first MMA).
Three load-bearing structural choices vs the default ping-pong-shaped matmul schedule:
- Mini-iter declaration order: body is (LOAD, FRAG, MMA) * 4 per K-partition (12 ops × 2 partitions = 24). After _construct_mma_blocks packs into MMA-centered blocks, each block holds exactly 1 frag + 1 global + 1 MMA, matching _run_iter's mini-iter shape.
- Cross-stage MMA fragment rotation: minis 3-4 emit MMA_LOAD_A[stage=os, sub=0] and MMA_LOAD_B[stage=os, sub=0], reading from the cross stage. Same-partition MMAs no longer touch A_quad[0]/B_quad[0] after mini-iter 2, so the cross-stage frag harmlessly overwrites them with data the next partition's leading MMA will consume.
- SchedulingStrategy.IDENTITY and no double_buffer_reorder: the body order is the final order. The framework's mma_block_interleave_list matches frags to MMAs by subtile, ignoring stage; running it on a cross-stage body would place a wrong-stage frag before the same-partition MMA. Bypassing the reorder preserves the per-mini-iter grouping.
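The mini-iter grouping and cross-stage rotation above can be modeled in a few lines. The following is an illustrative Python sketch, not the Mojo implementation; the op tuples and the exact sub-index mapping for the same-stage minis are assumptions:

```python
# Illustrative model of the 4-wave body order: per half (SMEM stage),
# four mini-iters of (LOAD, FRAG, MMA); the last two mini-iters
# frag-load the cross stage at sub=0 to pre-load the next half's
# leading quadrants. Op tuples are (kind, stage, index) -- hypothetical.

def build_half(stage: int) -> list[tuple]:
    cross = 1 - stage
    body = []
    for mini in range(4):
        # mini-iters 1-2 (0-indexed 0-1) read the same stage;
        # mini-iters 3-4 (0-indexed 2-3) pre-load the cross stage at sub=0
        frag_stage, sub = (stage, mini) if mini < 2 else (cross, 0)
        body.append(("LOAD", stage, mini))
        body.append(("FRAG", frag_stage, sub))
        body.append(("MMA", stage, mini))
    return body

body = build_half(0) + build_half(1)
assert len(body) == 24  # 12 ops x 2 K-partitions
# each MMA-centered block holds exactly 1 load + 1 frag + 1 MMA
blocks = [body[i:i + 3] for i in range(0, 24, 3)]
assert all(
    (b[0][0], b[1][0], b[2][0]) == ("LOAD", "FRAG", "MMA") for b in blocks
)
```

Under IDENTITY scheduling this declaration order is also the emission order, which is why the per-mini-iter grouping survives to the final schedule.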
The framework's default double_buffer_edge_rules include Phase 1
(FRAGMENT_LOAD → COMPUTE, same_half=True, use_config_match=True).
Cross-stage rotation is a cross-half register flow: half 0's mini-3-4
frags feed half 1's MMA(0,), and half 1's feed next-iter half 0's
MMA(0,). derive_edges overrides the default to append the
cross-half rules so wait derivation knows to drain frag lgkm at the
half boundary.
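The cross-half pairing that derive_edges appends can be sketched as follows; this is a hypothetical Python model of the rule's effect (the real edge-rule API and op encoding are not shown here):

```python
# Hypothetical sketch of the appended cross-half edge rule: each
# mini-3/4 frag-load is paired with the leading MMA of the *next* half
# (wrapping to the next iteration), so wait derivation knows to drain
# frag lgkm at the half boundary. Ops are (kind, half, index) tuples.

def cross_half_edges(n_halves: int = 2) -> list[tuple]:
    edges = []
    for half in range(n_halves):
        consumer_half = (half + 1) % n_halves  # next half, wraps around
        for mini in (2, 3):  # 0-indexed mini-iters 3-4
            edges.append((("FRAG", half, mini), ("MMA", consumer_half, 0)))
    return edges

edges = cross_half_edges()
# half 0's mini-3/4 frags feed half 1's leading MMA, and vice versa
assert (("FRAG", 0, 2), ("MMA", 1, 0)) in edges
assert (("FRAG", 1, 3), ("MMA", 0, 0)) in edges
```

A same_half-only rule would never generate these edges, which is exactly why the default Phase 1 rule alone is insufficient for this body.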
The framework's auto-prologue (derive_prologue_from_program) infers
initial frag-loads from the body's same-stage b.fan[2] ops; with this
body's cross-stage sub=0 frags it would read the cross stage at
k_base=0 (=k=BK data) instead of the same stage (=k=0 data) needed for
the first main iter's MMA[0,0]+MMA[0,1] (A_quad[0]+B_quad[0]). The
calling kernel inserts an explicit bootstrap pair after the framework
prologue to overwrite those quadrants with same-stage sub=0 data.
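The stage-vs-k mismatch the bootstrap pair fixes can be shown with a small numeric sketch; BK and the stage-to-k mapping below are illustrative assumptions:

```python
# Illustrative model: with double-buffered SMEM stages holding k-slices
# [k_base, k_base+BK) and [k_base+BK, k_base+2*BK), a cross-stage sub=0
# frag-load at k_base=0 reads k=BK data, not the k=0 data the first main
# iter's MMA[0,0]+MMA[0,1] needs from A_quad[0]/B_quad[0].

BK = 64  # hypothetical K-tile depth

def frag_k(k_base: int, same_stage: bool) -> int:
    # cross-stage reads land one K-tile ahead of the same-stage reads
    return k_base if same_stage else k_base + BK

assert frag_k(0, same_stage=False) == BK  # what auto-prologue would infer
assert frag_k(0, same_stage=True) == 0    # what the bootstrap pair loads
```

This is why the kernel cannot rely on derive_prologue_from_program alone: the inferred frags are correct in steady state but wrong for the very first iteration.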
comptime values
COMPUTE
comptime COMPUTE = FourWaveOps.COMPUTE.value
Integer tag for the reserved generic compute op.
FOUR_WAVE_MINI_ITERS
comptime FOUR_WAVE_MINI_ITERS = List(
    MiniIterSpec(LOAD_A, 0, 0, MMA_LOAD_B, 1, 1, False, 0, 0),
    MiniIterSpec(LOAD_B, 1, 0, MMA_LOAD_A, 0, 1, False, 0, 1),
    MiniIterSpec(LOAD_B, 1, 1, MMA_LOAD_A, 0, 0, True, 1, 0),
    MiniIterSpec(LOAD_A, 0, 1, MMA_LOAD_B, 1, 0, True, 1, 1),
    __list_literal__=NoneType(None),
)
4-wave's 4 mini-iters per K-partition.
Same shape across both partitions β only the SMEM stage flips. Mini-3/4's frag-loads read from the cross stage to pre-load the next partition's leading quadrants (A_quad[0] / B_quad[0]) while this partition's last MMAs still issue.
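The partition symmetry can be illustrated with a stripped-down Python mirror of the spec list; the field names and the reduced two-field struct below are assumptions, not the Mojo MiniIterSpec layout:

```python
# Hypothetical reduced model of FOUR_WAVE_MINI_ITERS: the same four
# mini-iters serve both K-partitions, and only the SMEM stage of the
# frag-load flips. Minis 3-4 (cross_stage_frag=True) read the cross
# stage to pre-load the next partition's leading quadrants.
from dataclasses import dataclass

@dataclass
class MiniIterSpec:
    load_op: str            # which DRAM->LDS prefetch issues this mini
    cross_stage_frag: bool  # True for minis 3-4

SPECS = [
    MiniIterSpec("LOAD_A", False),
    MiniIterSpec("LOAD_B", False),
    MiniIterSpec("LOAD_B", True),
    MiniIterSpec("LOAD_A", True),
]

def frag_stage(partition: int, spec: MiniIterSpec) -> int:
    # same shape across partitions; only the stage flips
    return (1 - partition) if spec.cross_stage_frag else partition

assert [frag_stage(0, s) for s in SPECS] == [0, 0, 1, 1]
assert [frag_stage(1, s) for s in SPECS] == [1, 1, 0, 0]
```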
LOAD_A
comptime LOAD_A = FourWaveOps.LOAD_A.value
Integer tag for A DRAM->LDS prefetch.
LOAD_B
comptime LOAD_B = FourWaveOps.LOAD_B.value
Integer tag for B DRAM->LDS prefetch.
MMA
comptime MMA = FourWaveOps.MMA.value
Integer tag for the MFMA quadrant compute op.
MMA_LOAD_A
comptime MMA_LOAD_A = FourWaveOps.MMA_LOAD_A.value
Integer tag for A LDS->register frag-load.
MMA_LOAD_B
comptime MMA_LOAD_B = FourWaveOps.MMA_LOAD_B.value
Integer tag for B LDS->register frag-load.
Structs
- FourWaveOps: Op tags for the 4-wave matmul kernel.
- MiniIterSpec: One mini-iter of the 4-wave body: (DRAM prefetch, frag-load, MMA).
- Pipeline4Wave: 4-wave pipeline schedule with cross-stage register rotation.
Functions
- build_schedule: Compiles the 4-wave pipeline schedule.