Mojo module
amd_4wave_conv
4-wave FP8 implicit-GEMM convolution for AMD MI355X (CDNA4).
Host launcher: amd_4wave_conv()
The kernel body lives on AMD4WaveMatmul.run_conv2d (in
linalg.matmul.gpu.amd.amd_4wave_matmul). The conv2d and matmul
share the same struct, the same 4-warp 2x2 quadrant layout, the same
MFMA shapes, and the same software-pipeline schedule. They differ
only in the A-operand loader: matmul uses TileLoaderLDS (linear
[M, K]); conv uses TileLoaderLDSIm2col (NHWC input + in-line
im2col address math). This file just bundles the launcher and the
HK chiplet/L2-swizzle helper.
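The in-line im2col address math is what turns the NHWC input into a virtual [M, K] A operand. The following is a minimal Python sketch of that mapping, not the kernel's code. It makes these assumptions:

- Each GEMM row m indexes (n, h_out, w_out).
- Each GEMM column k indexes (r, s, c) with k = r*S*C_in + s*C_in + c.
- Stride, dilation, and pad are the same in H and W.
- The output extent uses the usual formula.

Halo lanes return None where the kernel would route them to the SRD-OOB sentinel. The names `out_size` and `im2col_offset` are hypothetical.

```python
def out_size(in_size, k, stride, dilation, pad):
    """Standard convolution output extent along one spatial dimension."""
    return (in_size + 2 * pad - dilation * (k - 1) - 1) // stride + 1


def im2col_offset(m, k, *, H, W, C_in, R, S, stride, dilation, pad):
    """Map virtual A[m, k] to a flat NHWC input offset, or None if it reads zero.

    Assumes one stride/dilation/pad value shared by H and W (illustrative only).
    """
    if k >= R * S * C_in:
        return None  # K-padding column: zero filter rows make its value irrelevant
    H_out = out_size(H, R, stride, dilation, pad)
    W_out = out_size(W, S, stride, dilation, pad)
    # GEMM row -> (batch, output row, output column).
    n, rem = divmod(m, H_out * W_out)
    h_out, w_out = divmod(rem, W_out)
    # GEMM column -> (filter row, filter column, input channel).
    r, rem = divmod(k, S * C_in)
    s, c = divmod(rem, C_in)
    # Input pixel read by this filter tap.
    h_in = h_out * stride - pad + r * dilation
    w_in = w_out * stride - pad + s * dilation
    if not (0 <= h_in < H and 0 <= w_in < W):
        return None  # halo lane: the kernel uses the SRD-OOB sentinel here
    return ((n * H + h_in) * W + w_in) * C_in + c
```

For a 3×3, pad-1 filter over a 4×4×2 input, the corner output pixel's first tap falls in the halo (None). Its centre tap (k = 9) reads input channel 1 at pixel (0, 0), which is flat offset 1.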
Supported configuration space:
- Filter R × S: any R, S >= 1.
- Stride: any >= 1.
- Dilation: any >= 1.
- Pad: any >= 0 (halo lanes route to the SRD-OOB sentinel).
- Input dtype: FP8 (E4M3FN), BF16, or FP16. The output dtype is BF16 for FP8 input and matches the input for BF16/FP16. All dtypes route through the framework-scheduled body.
- C_in: any positive value. The loader picks between a fast-path
"uniform substrip per call" code path when
C_in % BK == 0 and BK <= C_in, and a slower per-lane-substrip path otherwise (e.g. ResNet stem with C_in = 64 and BK = 128).
- K = R*S*C_in: must be a multiple of 2*BK = 256 (the 4-wave
schedule's two-stage prologue requirement). When R*S*C_in isn't
aligned, the caller K-pads the filter buffer to a multiple of
2*BK by zero-filling the trailing K rows. The caller then passes the real C_in via
the C_in comptime kwarg on the launcher, so the loader uses the unpadded value in its address math. Zero filter rows make the MMA contribution of the padded K columns 0, regardless of what the A loader produces.
- num_splits: 1 (no split-K).
- BM × BN: 64×64, 128×128, or 128×256 (from the matmul's auto-pick;
overridable via
block_m_override/block_n_override).
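The C_in and K rules above reduce to a few lines of arithmetic. This is a hedged Python sketch. It assumes BK = 128, which is inferred from "2*BK = 256". The helper names `padded_k`, `uses_uniform_substrip`, and `pad_filter` are hypothetical and are not the launcher's API. The sketch covers three things:

- rounding K up to a multiple of 2*BK;
- the fast-path predicate quoted in the C_in bullet;
- zero-padding each [Cout, K] filter row so that padded columns contribute nothing to the dot product.

```python
BK = 128  # assumption: inferred from "2*BK = 256" in the K bullet


def padded_k(R, S, C_in, bk=BK):
    """Round K = R*S*C_in up to the next multiple of 2*bk (ceiling division)."""
    K = R * S * C_in
    step = 2 * bk
    return -(-K // step) * step


def uses_uniform_substrip(C_in, bk=BK):
    """Fast-path predicate as stated for the C_in loader choice."""
    return C_in % bk == 0 and bk <= C_in


def pad_filter(filter_rows, K_padded):
    """Append zeros so every filter row (one per output channel) spans K_padded columns."""
    return [list(row) + [0.0] * (K_padded - len(row)) for row in filter_rows]


def dot(a, b):
    """Plain dot product, standing in for one MMA output element."""
    return sum(x * y for x, y in zip(a, b))
```

A 7×7 filter with C_in = 3 has K = 147 and pads to 256. Whatever the A loader puts in columns 147..255 is multiplied by zero filter entries, so the result equals the unpadded dot product.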
The conv launcher takes the 4D NHWC input directly. The filter is a
2D [Cout, K_padded] tile-tensor in FRSC order (filter row f
column k = r*S*C_in + s*C_in + c is weight[f, r, s, c]). The
output is a 2D [N*H_out*W_out, Cout] view of the NHWC output
buffer; for packed NHWC layouts the 2D view aliases the same bytes
exactly. See max/kernels/test/gpu/nn/test_amd_4wave_conv*.mojo for
end-to-end correctness coverage.
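The filter and output layouts described above can be checked with plain index arithmetic. This is a minimal Python sketch, assuming nested-list weights indexed weight[f][r][s][c] and a densely packed NHWC output. It has three parts:

- packing weights into FRSC rows, zero-padded to K_padded;
- the column index k = r*S*C_in + s*C_in + c;
- a check that the 2D [N*H_out*W_out, Cout] offset equals the NHWC offset when the buffer is packed.

All function names are hypothetical.

```python
def frsc_column(r, s, c, S, C_in):
    """Column of the 2D [Cout, K_padded] filter that holds weight[f, r, s, c]."""
    return r * S * C_in + s * C_in + c


def pack_filter(weight, R, S, C_in, K_padded):
    """Flatten weight[f][r][s][c] into FRSC rows, zero-filling columns >= R*S*C_in."""
    rows = []
    for w_f in weight:
        row = [0.0] * K_padded
        for r in range(R):
            for s in range(S):
                for c in range(C_in):
                    row[frsc_column(r, s, c, S, C_in)] = w_f[r][s][c]
        rows.append(row)
    return rows


def output_2d_offset(m, f, Cout):
    """Flat offset of element [m, f] in the 2D [N*H_out*W_out, Cout] view."""
    return m * Cout + f


def output_nhwc_offset(n, h, w, f, H_out, W_out, Cout):
    """Flat offset of element [n, h, w, f] in a densely packed NHWC buffer."""
    return ((n * H_out + h) * W_out + w) * Cout + f
```

With m = (n*H_out + h)*W_out + w, both offset functions return the same value. This is why the 2D view aliases the NHWC bytes exactly, but only while the buffer has no row or pixel padding.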
Functions
- amd_4wave_conv: Launches the 4-wave implicit-GEMM convolution on the device.