Mojo module

amd_4wave_conv_residual

AMD 4-wave Conv2D fprop with optional residual add.

Mirrors max/kernels/src/nn/conv/gpu/nvidia/sm100/conv2d.mojo's conv2d_fprop_with_residual API so the production-level dispatcher can swap SM100 (Blackwell) for AMD MI355X (CDNA4) without touching the call site. Same Conv2dProblemShape, same elementwise_lambda_fn and elementwise_compute_lambda_fn hooks, same D = Conv(A,B) + beta*C semantics.

The residual add fires inside AMD4WaveMatmul.run_conv2d's epilogue, which bulk-prefetches source into a per-lane VGPR cluster before the main loop so the HBM read latency overlaps with the MFMAs. No extra HBM round-trip, no separate elementwise launch.

When has_residual=False (or beta == 0.0), the call routes to amd_4wave_conv directly: same code path, no residual cost.
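The routing rule above can be pictured with a small Python sketch. This is only a model of the decision, not the Mojo implementation: `choose_path` and the string labels it returns are hypothetical names introduced for illustration, and the real wrapper's signature is not shown on this page.

```python
def choose_path(has_residual: bool, beta: float) -> str:
    """Model of the routing decision (hypothetical helper, not the Mojo API).

    Returns a label naming the path the call would take.
    """
    # No residual requested, or a zero scale factor: skip the residual
    # machinery entirely and use the plain convolution entry point.
    if not has_residual or beta == 0.0:
        return "amd_4wave_conv"
    # Otherwise the residual add is fused into the kernel epilogue.
    return "fused_residual_epilogue"


plain = choose_path(has_residual=False, beta=1.0)
zero_beta = choose_path(has_residual=True, beta=0.0)
fused = choose_path(has_residual=True, beta=0.5)
```

Note that either condition alone is enough to take the plain path, so a caller that passes a residual tensor with beta == 0.0 pays nothing for it.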

Epilogue ordering matches SM100: D = lambda(Conv(A,B)) + beta * C, i.e. elementwise_compute_lambda_fn (pre-residual: bias / ReLU / SiLU / GELU) fires on the post-cast c_type MMA output before the residual FMA, and elementwise_lambda_fn (post-residual: void store-site lambda) fires after with the fused value.
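To make the ordering concrete, here is a scalar reference model in plain Python. It is a sketch under assumptions, not the kernel: `reference_epilogue`, `relu`, and `record` are hypothetical names, the tensors are flattened to lists, and the cast to c_type is omitted. It only illustrates that the compute lambda runs before the residual FMA and the store lambda runs after it.

```python
from typing import Callable, Dict, List


def relu(x: float) -> float:
    """Example pre-residual activation (stands in for elementwise_compute_lambda_fn)."""
    return x if x > 0.0 else 0.0


def reference_epilogue(
    conv_out: List[float],
    residual: List[float],
    beta: float,
    compute_fn: Callable[[float], float],
    store_fn: Callable[[int, float], None],
) -> None:
    """Scalar model of D = lambda(Conv(A,B)) + beta * C (hypothetical helper)."""
    for i, acc in enumerate(conv_out):
        activated = compute_fn(acc)             # 1. pre-residual lambda
        fused = activated + beta * residual[i]  # 2. residual multiply-add
        store_fn(i, fused)                      # 3. post-residual store-site lambda


stored: Dict[int, float] = {}


def record(i: int, value: float) -> None:
    """Example void store-site lambda: captures the fused value."""
    stored[i] = value


reference_epilogue([-1.0, 2.0], [10.0, 10.0], 0.5, relu, record)
# stored == {0: 5.0, 1: 7.0}
```

The first element shows why the order matters: ReLU clamps -1.0 to 0.0 before the residual is added, giving 5.0. Applying ReLU after the residual add would have produced 4.0 instead.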

Functions