
Mojo module

amd_4wave_conv

4-wave FP8 implicit-GEMM convolution for AMD MI355X (CDNA4).

Host launcher: amd_4wave_conv()

The kernel body lives in AMD4WaveMatmul.run_conv2d (in linalg.matmul.gpu.amd.amd_4wave_matmul). Conv2d and matmul share the same struct, the same 4-warp 2x2 quadrant layout, the same MFMA shapes, and the same software-pipeline schedule; they differ only in the A-operand loader. Matmul uses TileLoaderLDS (linear [M, K] layout), while conv uses TileLoaderLDSIm2col (NHWC input plus in-line im2col address math). This module only bundles the launcher and the HK chiplet/L2-swizzle helper.
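To make "in-line im2col address math" concrete, here is a minimal Python sketch of the index mapping an implicit-GEMM A loader has to perform: GEMM row m selects an output pixel, GEMM column k selects a filter tap and channel, and together they pick one NHWC input element. It assumes the FRSC column order (k = r*S*C_in + s*C_in + c) and the [N*H_out*W_out, Cout] output-row order described for the launcher, plus the standard convolution output-size formula. The names conv_out_size and im2col_coord are hypothetical; this is an illustration of the technique, not the code in TileLoaderLDSIm2col. Halo positions return None as a stand-in for the SRD-OOB sentinel route.

```python
def conv_out_size(in_size, k, stride, pad, dilation):
    """Standard convolution output extent along one spatial axis."""
    return (in_size + 2 * pad - dilation * (k - 1) - 1) // stride + 1


def im2col_coord(m, k, *, N, H, W, C_in, R, S, stride=1, pad=0, dilation=1):
    """Map an implicit-GEMM coordinate (m, k) to an NHWC input index.

    Hypothetical illustration only. Returns (n, h, w, c), or None when the
    filter tap lands in the zero-padding halo.
    """
    H_out = conv_out_size(H, R, stride, pad, dilation)
    W_out = conv_out_size(W, S, stride, pad, dilation)
    assert 0 <= m < N * H_out * W_out
    assert 0 <= k < R * S * C_in  # padded K columns are not modelled here

    # GEMM row m -> output pixel (n, h_out, w_out).
    n, rem = divmod(m, H_out * W_out)
    h_out, w_out = divmod(rem, W_out)

    # GEMM column k -> filter tap (r, s) and channel c, FRSC order.
    r, rem = divmod(k, S * C_in)
    s, c = divmod(rem, C_in)

    # Input pixel read by this tap (stride, padding and dilation applied).
    h = h_out * stride - pad + r * dilation
    w = w_out * stride - pad + s * dilation
    if not (0 <= h < H and 0 <= w < W):
        return None  # halo lane
    return (n, h, w, c)


# 4x4 input with C_in = 2, 3x3 filter, pad 1: the output is also 4x4.
print(im2col_coord(0, 0, N=1, H=4, W=4, C_in=2, R=3, S=3, pad=1))  # None
print(im2col_coord(0, 9, N=1, H=4, W=4, C_in=2, R=3, S=3, pad=1))  # (0, 0, 0, 1)
```

The first call hits the top-left halo (h = w = -1); the second uses tap (r, s) = (1, 1), which lands on input pixel (0, 0).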

Supported configuration space:

  • Filter R × S: any R, S >= 1.
  • Stride: any >= 1.
  • Dilation: any >= 1.
  • Pad: any >= 0 (halo lanes route to the SRD-OOB sentinel).
  • Input dtype: FP8 (E4M3FN), BF16, or FP16. The output dtype is BF16 for FP8 input and matches the input dtype for BF16/FP16. All dtypes run through the framework-scheduled body.
  • C_in: any positive value. The loader picks a fast "uniform substrip per call" path when C_in % BK == 0 and BK <= C_in, and a slower per-lane-substrip path otherwise (e.g. the ResNet stem with C_in = 64 and BK = 128).
  • K = R·S·C_in: must be a multiple of 2·BK = 256 (required by the 4-wave schedule's two-stage prologue). When R·S·C_in isn't aligned, the caller K-pads the filter buffer to a multiple of 2·BK by zero-filling the trailing K entries, and passes the real C_in via the C_in comptime kwarg on the launcher so the loader uses the unpadded value in its address math. Because those filter entries are zero, the padded K columns contribute 0 to the MMA regardless of what the A loader produces for them.
  • num_splits: 1 (no split-K).
  • BM × BN: 64×64, 128×128, or 128×256 (chosen by the matmul's auto-pick; overridable via block_m_override / block_n_override).
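The K-alignment rule and the loader's fast-path condition above reduce to a little integer arithmetic. The sketch below assumes BK = 128 (implied by 2·BK = 256); padded_k and uses_uniform_substrip are hypothetical helper names that only restate the rules from the list, not functions exported by this module.

```python
BK = 128  # assumed K-tile size, implied by 2*BK = 256


def padded_k(R, S, C_in, bk=BK):
    """Round K = R*S*C_in up to the next multiple of 2*BK (zero-filled tail)."""
    k = R * S * C_in
    step = 2 * bk
    return ((k + step - 1) // step) * step


def uses_uniform_substrip(C_in, bk=BK):
    """Fast-path predicate as stated in the C_in bullet."""
    return C_in % bk == 0 and bk <= C_in


# (R, S, C_in) -> real K, padded K, fast loader path?
for R, S, C_in in [(3, 3, 64), (7, 7, 3), (1, 1, 256)]:
    print(R, S, C_in, R * S * C_in, padded_k(R, S, C_in), uses_uniform_substrip(C_in))
```

Under these assumptions, a 3×3 filter with C_in = 64 has K = 576, which pads to 768, and the loader takes the per-lane-substrip path because 64 is not a multiple of 128.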

The conv launcher takes the 4D NHWC input directly. The filter is a 2D [Cout, K_padded] tile-tensor in FRSC order: element [f, k], with k = r*S*C_in + s*C_in + c, holds weight[f, r, s, c]. The output is a 2D [N*H_out*W_out, Cout] view of the NHWC output buffer; for packed NHWC layouts, this 2D view aliases exactly the same bytes. See max/kernels/test/gpu/nn/test_amd_4wave_conv*.mojo for end-to-end correctness coverage.
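The FRSC flattening and the 2D output view are both plain index arithmetic. The Python sketch below uses nested lists in place of tile-tensors: it packs a [Cout][R][S][C_in] weight into [Cout, K_padded] rows with a zero-filled tail, and computes which row of the [N*H_out*W_out, Cout] view holds output pixel (n, h_out, w_out). pack_filter_frsc and output_row are hypothetical helpers, not part of the launcher's API, and the row formula assumes a packed NHWC output.

```python
def pack_filter_frsc(weight, R, S, C_in, k_padded):
    """Flatten weight[f][r][s][c] into rows of length k_padded (FRSC order).

    Column k = r*S*C_in + s*C_in + c holds weight[f][r][s][c]; columns at or
    beyond R*S*C_in stay zero so they contribute nothing to the GEMM.
    """
    assert k_padded >= R * S * C_in
    packed = []
    for f_weights in weight:
        row = [0.0] * k_padded
        for r in range(R):
            for s in range(S):
                for c in range(C_in):
                    row[r * S * C_in + s * C_in + c] = f_weights[r][s][c]
        packed.append(row)
    return packed


def output_row(n, h_out, w_out, H_out, W_out):
    """Row of the 2D [N*H_out*W_out, Cout] view for a packed NHWC output."""
    return (n * H_out + h_out) * W_out + w_out


# Cout = 2, R = 1, S = 2, C_in = 2; weight value encodes f*100 + s*10 + c.
w = [[[[f * 100 + s * 10 + c for c in range(2)] for s in range(2)]
      for r in range(1)] for f in range(2)]
print(pack_filter_frsc(w, R=1, S=2, C_in=2, k_padded=8)[1])
print(output_row(1, 2, 3, H_out=4, W_out=5))
```

Because row m of the output view is contiguous with stride Cout, element [m, co] sits at the same offset as [n, h_out, w_out, co] in a packed NHWC buffer, which is why the two views can alias.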

Functions