
Mojo module

amd_4wave_conv_residual

AMD 4-wave Conv2D fprop with optional residual add.

Mirrors max/kernels/src/nn/conv/gpu/nvidia/sm100/conv2d.mojo's conv2d_fprop_with_residual API so the production-level dispatcher can swap SM100 (Blackwell) for AMD MI355X (CDNA4) without touching the call site. Same Conv2dProblemShape, same elementwise_lambda_fn and elementwise_compute_lambda_fn hooks, same D = Conv(A,B) + beta*C semantics.

The residual add fires inside AMD4WaveMatmul.run_conv2d's epilogue, which bulk-prefetches the residual source (the C operand) into a per-lane VGPR cluster before the main loop so that the HBM read latency overlaps with the MFMAs. No extra HBM round-trip, no separate elementwise launch.

When has_residual=False (or beta == 0.0), the call routes to amd_4wave_conv directly: same code path, no residual cost.
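The routing rule above can be sketched in plain Python. This is only an illustration of the condition, not the dispatcher itself: `takes_residual_path` and `select_path` are hypothetical names, and the returned strings are labels rather than real entry points.

```python
def takes_residual_path(has_residual: bool, beta: float) -> bool:
    """Hypothetical predicate: True only when the residual add does real work."""
    return has_residual and beta != 0.0


def select_path(has_residual: bool, beta: float) -> str:
    """Return a label for the path the call would take (labels are illustrative)."""
    if takes_residual_path(has_residual, beta):
        # Residual add is fused into the conv epilogue.
        return "conv + fused residual epilogue"
    # No residual work to do: fall through to the plain conv path.
    return "amd_4wave_conv"
```

Under this sketch, both `select_path(False, 1.0)` and `select_path(True, 0.0)` pick the plain `amd_4wave_conv` label, and only a call with a residual and a nonzero `beta` takes the fused path.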

Epilogue ordering matches SM100: D = lambda(Conv(A,B)) + beta * C. The elementwise_compute_lambda_fn hook (pre-residual: bias / ReLU / SiLU / GELU) fires on the post-cast c_type MMA output, before the residual FMA. The elementwise_lambda_fn hook (post-residual: void store-site lambda) fires afterward with the fused value.
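To make the ordering concrete, here is a minimal scalar reference in Python. It is a sketch under simplifying assumptions, not the kernel: `reference_epilogue`, `compute_fn`, and `store_fn` are hypothetical stand-ins for the two hooks, tensors are flattened to lists, coordinates are reduced to a flat index, and the cast to c_type is omitted.

```python
from typing import Callable, List, Optional, Tuple


def reference_epilogue(
    conv_out: List[float],
    source: Optional[List[float]],
    beta: float,
    compute_fn: Optional[Callable[[int, float], float]] = None,
    store_fn: Optional[Callable[[int, float], None]] = None,
) -> List[float]:
    """Apply D = compute_fn(Conv(A,B)) + beta * C element by element."""
    result = []
    for i, acc in enumerate(conv_out):
        # 1. Pre-residual hook (bias / activation) sees the conv output.
        value = compute_fn(i, acc) if compute_fn is not None else acc
        # 2. Residual FMA, skipped when there is no source or beta is zero.
        if source is not None and beta != 0.0:
            value = value + beta * source[i]
        # 3. Post-residual store-site hook sees the fused value.
        if store_fn is not None:
            store_fn(i, value)
        result.append(value)
    return result


# Example: ReLU as the pre-residual hook, a recorder as the store-site hook.
stored: List[Tuple[int, float]] = []
fused = reference_epilogue(
    conv_out=[-1.0, 2.0],
    source=[4.0, 4.0],
    beta=0.5,
    compute_fn=lambda i, x: max(x, 0.0),
    store_fn=lambda i, x: stored.append((i, x)),
)
# fused == [2.0, 4.0]
```

The first element shows why the order matters: ReLU clamps -1.0 to 0.0 before the residual adds 2.0, giving 2.0. Applying ReLU after the residual add would give 1.0 instead.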

Functions