
Mojo module

amd_4wave_conv

4-wave FP8 implicit-GEMM convolution for AMD MI355X (CDNA4).

Host launcher: amd_4wave_conv()

The kernel body lives on AMD4WaveMatmul.run_conv2d (in linalg.matmul.gpu.amd.amd_4wave_matmul). Conv2d and matmul share the same struct, the same 4-warp 2x2 quadrant layout, the same MFMA shapes, and the same software-pipeline schedule. They differ only in the A-operand loader: matmul uses TileLoaderLDS (linear [M, K]); conv uses TileLoaderLDSIm2col (NHWC input + in-line im2col address math). This file just bundles the launcher and the HK chiplet/L2-swizzle helper.

Supported configuration space:

  • Filter R × S: any R, S >= 1.
  • Stride: any >= 1.
  • Dilation: any >= 1.
  • Pad: any >= 0 (halo lanes route to the SRD-OOB sentinel).
  • Input dtype: FP8 (E4M3FN), BF16, or FP16. Output dtype is BF16 (FP8 in, BF16 out) or matches the input for BF16/FP16. All dtypes route through the framework-scheduled body.
  • C_in: any positive value. The loader picks between a fast "uniform substrip per call" code path when C_in % BK == 0 and BK <= C_in, and a slower per-lane-substrip code path otherwise (e.g. ResNet stem with C_in = 64 and BK = 128).
  • K = R*S*C_in: must be a multiple of 2*BK = 256 (the 4-wave schedule's two-stage prologue requirement). When R*S*C_in isn't aligned, the caller K-pads the filter buffer to a multiple of 2*BK by zero-filling the trailing K rows; pass the real C_in via the C_in comptime kwarg on the launcher so the loader uses the unpadded value in its address math. Zero filter rows make the MMA contribution for padded K columns 0 regardless of what the A loader produces.
  • num_splits: 1 (no split-K).
  • BM × BN: 64×64, 128×128, or 128×256 (from the matmul's auto-pick; overridable via block_m_override / block_n_override).
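
The K-alignment and C_in rules above can be sketched on the host side as follows. This is an illustrative plain-Python sketch, not part of the Mojo module: the helper names padded_k and uses_fast_path are hypothetical, and BK = 128 is an assumption inferred from "2BK = 256".

```python
BK = 128  # assumed K-tile depth, inferred from "2*BK = 256" in the list above


def padded_k(R: int, S: int, C_in: int, bk: int = BK) -> int:
    """Round K = R*S*C_in up to the next multiple of 2*bk (hypothetical helper)."""
    k = R * S * C_in
    step = 2 * bk
    return ((k + step - 1) // step) * step


def uses_fast_path(C_in: int, bk: int = BK) -> bool:
    """Restate the documented loader condition: C_in % BK == 0 and BK <= C_in."""
    return C_in % bk == 0 and bk <= C_in
```

For a 3x3 filter with C_in = 64, K = 576 is not a multiple of 256, so the caller would allocate and zero-fill a filter buffer with padded_k(3, 3, 64) K entries per output channel while still passing the real C_in = 64 to the launcher.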

The conv launcher takes the 4D NHWC input directly. The filter is a 2D [Cout, K_padded] tile-tensor in FRSC order (filter row f, column k = r*S*C_in + s*C_in + c is weight[f, r, s, c]). The output is a 2D [N*H_out*W_out, Cout] view of the NHWC output buffer; for packed NHWC layouts the 2D view aliases the same bytes exactly. See max/kernels/test/gpu/nn/test_amd_4wave_conv*.mojo for end-to-end correctness coverage.
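
To make the layouts concrete, here is a minimal plain-Python reference of the implicit GEMM they describe. It is a sketch under stated assumptions, not the kernel's implementation: all function names are hypothetical, H_out and W_out use the standard convolution output-size formula, halo reads are treated as zero, and the A value for padded K columns is simply set to zero (the page notes the zero filter rows make that value irrelevant).

```python
def pack_filter_frsc(weight, k_padded):
    """Flatten weight[f][r][s][c] into rows of length k_padded (FRSC order),
    zero-filling every column past R*S*C_in."""
    packed = []
    for f_w in weight:
        row = [v for r_w in f_w for s_w in r_w for v in s_w]
        assert len(row) <= k_padded
        packed.append(row + [0.0] * (k_padded - len(row)))
    return packed


def im2col_element(x, m, k, R, S, C_in, H_out, W_out, stride, dilation, pad):
    """Element (m, k) of the implicit A matrix for NHWC input x[n][h][w][c]."""
    if k >= R * S * C_in:
        return 0.0  # padded K column (sketch choice; the filter is zero here)
    r, rem = divmod(k, S * C_in)       # k = r*S*C_in + s*C_in + c
    s, c = divmod(rem, C_in)
    n, rem = divmod(m, H_out * W_out)  # m = n*H_out*W_out + ho*W_out + wo
    ho, wo = divmod(rem, W_out)
    h = ho * stride - pad + r * dilation
    w = wo * stride - pad + s * dilation
    if 0 <= h < len(x[0]) and 0 <= w < len(x[0][0]):
        return x[n][h][w][c]
    return 0.0  # halo position, assumed to read as zero


def conv2d_implicit_gemm(x, packed, R, S, C_in, stride=1, dilation=1, pad=0):
    """Return the 2D [N*H_out*W_out, Cout] output as a list of rows."""
    N, H, W = len(x), len(x[0]), len(x[0][0])
    H_out = (H + 2 * pad - dilation * (R - 1) - 1) // stride + 1
    W_out = (W + 2 * pad - dilation * (S - 1) - 1) // stride + 1
    K = len(packed[0])
    out = []
    for m in range(N * H_out * W_out):
        a_row = [im2col_element(x, m, k, R, S, C_in, H_out, W_out,
                                stride, dilation, pad) for k in range(K)]
        out.append([sum(a * b for a, b in zip(a_row, f_row)) for f_row in packed])
    return out
```

A reference like this is only useful for checking small shapes by hand; the k_padded value in a real call would have to satisfy the 2*BK alignment rule, which the sketch does not enforce.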

Functions