Mojo module
amd_4wave_matmul
4-wave FP8 matmul for AMD MI355X (CDNA4).
Entry point: AMD4WaveMatmul.run()
Host launcher: amd_4wave_matmul()
A hand-written, line-by-line port of HipKittens FP8_4wave's matmul_device_1024 (BM=64) and matmul_device_2048 (BM=128).
Mirrors the source kernel's exact structure:
- 4 mini-iters per loop iter, each with G_load + frag_load + mma_ABt
- Cross-stage register rotation: a[0]/b[0] are reloaded mid-iter from the next stage, so iter k+1's first MMA can fire without waiting on LDS
- Explicit s_waitcnt vmcnt(N) values matching the source's empirical tuning (vmcnt(7) / vmcnt(6) / vmcnt(4) / vmcnt(2) / vmcnt(0))
- Mid-iter s_barrier between mini-iters 2 and 3
- 2-iter epilogue drain
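The explicit vmcnt values make sense once you recall how the counter behaves. On AMD GPUs each wave tracks its outstanding vector-memory operations in vmcnt, `s_waitcnt vmcnt(N)` stalls until at most N are still outstanding, and loads retire in issue order. A wait therefore guarantees that every load except the N most recent has landed. The toy model below is a sketch of that rule only; `VmcntModel` and its methods are hypothetical and not part of this module.

```python
from collections import deque
from typing import List


class VmcntModel:
    """Hypothetical toy model of one wave's vmcnt counter (not the module's API)."""

    def __init__(self) -> None:
        self.in_flight: deque = deque()  # outstanding load tags, oldest first
        self.completed: List[str] = []   # tags whose data is guaranteed present

    def issue(self, tag: str) -> None:
        # Issuing a global load (G_load) increments vmcnt.
        self.in_flight.append(tag)

    def s_waitcnt_vmcnt(self, n: int) -> List[str]:
        # Stall until at most n loads remain outstanding. Loads retire in
        # issue order, so the oldest ones are the ones guaranteed complete.
        while len(self.in_flight) > n:
            self.completed.append(self.in_flight.popleft())
        return list(self.in_flight)


wave = VmcntModel()
for k in range(8):
    wave.issue(f"G_load{k}")
still_pending = wave.s_waitcnt_vmcnt(6)
print(wave.completed)  # ['G_load0', 'G_load1']
print(still_pending)   # ['G_load2', ..., 'G_load7']
```

Under this model a tuned value like vmcnt(6) lets the six most recently issued loads stay in flight while the wave consumes the data from older ones. vmcnt(0) fully drains the queue, which matches its use near the end of the loop.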
The SMEM organization, TileLoaderLDS, and QuadrantMmaOp are reused across the hand-written and framework-driven body strategies; the two paths share all scaffolding except the loop body itself.
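The sharing described above follows a familiar pattern: one kernel skeleton that takes the loop body as a pluggable piece. The sketch below is a hypothetical Python illustration of that split. `run_kernel`, `setup_smem`, `store_c`, and both body functions are invented names. They stand in for the module's SMEM setup, TileLoaderLDS, and QuadrantMmaOp scaffolding and do not represent its real interfaces.

```python
from typing import Callable, List

Body = Callable[[List[str], int], None]


# Hypothetical stand-ins for the shared scaffolding (not the module's API).
def setup_smem(trace: List[str]) -> None:
    trace.append("alloc SMEM tiles")


def store_c(trace: List[str]) -> None:
    trace.append("store C")


def run_kernel(body: Body, num_iters: int) -> List[str]:
    """Shared skeleton: everything except the loop body is common."""
    trace: List[str] = []
    setup_smem(trace)
    for k in range(num_iters):
        body(trace, k)  # the only strategy-specific part
    store_c(trace)
    return trace


def hand_written_body(trace: List[str], k: int) -> None:
    trace.append(f"hand-written iter {k}")


def scheduled_body(trace: List[str], k: int) -> None:
    trace.append(f"scheduled iter {k}")


print(run_kernel(hand_written_body, 2))
# ['alloc SMEM tiles', 'hand-written iter 0', 'hand-written iter 1', 'store C']
```

Keeping the body as the only variable part means the hand-written and schedule-compiler-driven launchers can be compared on identical memory layout and MMA plumbing.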
Structs
- AMD4WaveMatmul: Hand-written 4-warp 2x2 inline-MMA matmul for AMD MI355X.
- KernelConfig: Block/warp/MMA shape configuration for 4-wave-simple kernels.
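A "4-warp 2x2" layout paired with a QuadrantMmaOp suggests that each of the four warps owns one quadrant of the block tile. The sketch below shows that arithmetic under stated assumptions:

- `ToyKernelConfig` and its field names are hypothetical, not KernelConfig's actual fields.
- Warps are assumed to be placed row-major.
- BN is assumed to equal BM. Only the BM values (64 and 128) come from the text above.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class ToyKernelConfig:
    """Hypothetical stand-in for KernelConfig; field names are illustrative."""

    BM: int           # block tile rows
    BN: int           # block tile columns
    warps_m: int = 2  # warp grid rows (2x2 for the 4-wave kernel)
    warps_n: int = 2  # warp grid columns

    def __post_init__(self) -> None:
        if self.BM % self.warps_m or self.BN % self.warps_n:
            raise ValueError("block tile must split evenly across the warp grid")

    @property
    def num_warps(self) -> int:
        return self.warps_m * self.warps_n

    def warp_tile(self) -> Tuple[int, int]:
        # Each warp owns one (BM/2) x (BN/2) quadrant of the block tile.
        return self.BM // self.warps_m, self.BN // self.warps_n

    def warp_origin(self, warp_id: int) -> Tuple[int, int]:
        # Assumed row-major placement of warps in the 2x2 grid.
        wm, wn = divmod(warp_id, self.warps_n)
        tm, tn = self.warp_tile()
        return wm * tm, wn * tn


for bm in (64, 128):
    cfg = ToyKernelConfig(BM=bm, BN=bm)  # BN == BM is an assumption
    print(bm, cfg.num_warps, cfg.warp_tile(), [cfg.warp_origin(w) for w in range(4)])
```

With BM=64 each warp covers a 32x32 quadrant. With BM=128 each quadrant grows to 64x64, so every warp does four times the MMA work per block tile.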
Functions
- amd_4wave_matmul: Launches the hand-written 4-wave matmul on the device.
- amd_4wave_scheduled_matmul: Launches the schedule-compiler-driven 4-wave matmul on the device.
- s_barrier:
- s_setprio:
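On AMD GPUs, `s_barrier` makes each wave in a workgroup stall until every wave has reached it. That is the guarantee the mid-iter barrier between mini-iters 2 and 3 relies on. The sketch below models the four waves as Python threads and `threading.Barrier` as the barrier. It is an analogy, not the module's s_barrier wrapper. The 0-based mini-iter numbering and the barrier position (`BARRIER_BEFORE = 3`) are assumptions about how "between mini-iters 2 and 3" maps to indices.

```python
import threading
from typing import List, Tuple

NUM_WAVES = 4
MINI_ITERS = 4
BARRIER_BEFORE = 3  # assumed: barrier sits between mini-iters 2 and 3 (0-based)

barrier = threading.Barrier(NUM_WAVES)  # plays the role of s_barrier
log: List[Tuple[int, int]] = []         # (mini_iter, wave_id) in completion order
log_lock = threading.Lock()


def wave(wave_id: int) -> None:
    for m in range(MINI_ITERS):
        if m == BARRIER_BEFORE:
            barrier.wait()  # no wave proceeds until all four arrive
        with log_lock:
            log.append((m, wave_id))


threads = [threading.Thread(target=wave, args=(w,)) for w in range(NUM_WAVES)]
for t in threads:
    t.start()
for t in threads:
    t.join()

pre = NUM_WAVES * BARRIER_BEFORE  # entries that must precede the barrier
print(all(m < BARRIER_BEFORE for m, _ in log[:pre]))  # True
```

The interleaving of entries before the barrier varies from run to run. However, none of the mini-iter 3 entries can appear until all twelve earlier entries are logged. That ordering is the property a kernel needs before one wave reads LDS data that another wave wrote.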