Mojo module
amd_4wave_conv_residual
AMD 4-wave Conv2D fprop with optional residual add.
Mirrors the conv2d_fprop_with_residual API in
max/kernels/src/nn/conv/gpu/nvidia/sm100/conv2d.mojo, so the production-level
dispatcher can swap SM100 (Blackwell) for AMD MI355X (CDNA4) without touching
the call site. It takes the same Conv2dProblemShape, exposes the same
elementwise_lambda_fn and elementwise_compute_lambda_fn hooks, and implements
the same D = Conv(A,B) + beta*C semantics.
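When checking the fused kernel's output, it helps to write the D = Conv(A,B) + beta*C contract as a naive host-side reference. The following is a minimal Python sketch, not the Mojo API. The NHWC activation and KRSC filter layouts, the single stride/padding pair, unit dilation, and the helper names conv2d_fprop_reference and conv2d_fprop_with_residual_reference are all assumptions made for illustration. Conv is computed as cross-correlation (no filter flip), which is the usual convention for conv2d fprop.

```python
# Naive reference for D = Conv(A, B) + beta * C.
# Assumptions (not taken from the Mojo API): NHWC activations, KRSC filters,
# unit dilation, one stride/padding value for both spatial dims, nested lists.

def conv2d_fprop_reference(a, b, stride=1, pad=0):
    """a: [N][H][W][C], b: [K][R][S][C] -> out: [N][P][Q][K]."""
    n_, h_, w_, c_ = len(a), len(a[0]), len(a[0][0]), len(a[0][0][0])
    k_, r_, s_ = len(b), len(b[0]), len(b[0][0])
    p_ = (h_ + 2 * pad - r_) // stride + 1
    q_ = (w_ + 2 * pad - s_) // stride + 1
    out = [[[[0.0] * k_ for _ in range(q_)] for _ in range(p_)] for _ in range(n_)]
    for n in range(n_):
        for p in range(p_):
            for q in range(q_):
                for k in range(k_):
                    acc = 0.0
                    for r in range(r_):
                        for s in range(s_):
                            h = p * stride - pad + r
                            w = q * stride - pad + s
                            if 0 <= h < h_ and 0 <= w < w_:  # zero padding
                                for c in range(c_):
                                    acc += a[n][h][w][c] * b[k][r][s][c]
                    out[n][p][q][k] = acc
    return out

def conv2d_fprop_with_residual_reference(a, b, c, beta, stride=1, pad=0):
    """D = Conv(A, B) + beta * C, where C has the conv output shape [N][P][Q][K]."""
    d = conv2d_fprop_reference(a, b, stride, pad)
    for n, plane in enumerate(d):
        for p, row in enumerate(plane):
            for q, px in enumerate(row):
                for k in range(len(px)):
                    px[k] += beta * c[n][p][q][k]
    return d
```

With a 3x3 all-ones input and filter and pad=1, the center output is 9.0 and a corner is 4.0. Adding an all-ones C with beta=0.5 gives 9.5 and 4.5.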
The residual add fires inside the epilogue of AMD4WaveMatmul.run_conv2d,
which bulk-prefetches the source tensor C into a per-lane VGPR cluster
before the main loop, so the HBM read latency overlaps with the MFMAs.
The residual therefore costs no extra HBM round-trip and no separate
elementwise kernel launch.
When has_residual=False (or beta == 0.0), the call routes directly to
amd_4wave_conv: the same code path, with no residual cost.
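The routing rule can be modeled with a few lines of scalar code. This is a simplified Python sketch: conv_stub and fprop_with_residual are hypothetical stand-ins (a 1-D dot product plays the role of the convolution) and do not reflect the real signatures of amd_4wave_conv or amd_4wave_conv_fprop_with_residual.

```python
import math

def conv_stub(a, b):
    # Hypothetical stand-in for amd_4wave_conv: a 1-D dot product.
    return sum(x * y for x, y in zip(a, b))

def fprop_with_residual(a, b, c, beta, has_residual=True):
    # Documented routing rule: no residual requested, or beta == 0.0,
    # means the plain conv path is taken and C is never read.
    if not has_residual or beta == 0.0:
        return conv_stub(a, b)
    # Otherwise the residual is fused into the same pass.
    return conv_stub(a, b) + beta * c

nan = float("nan")
print(fprop_with_residual([1.0, 2.0], [3.0, 4.0], nan, beta=0.0))  # 11.0
print(math.isnan(conv_stub([1.0, 2.0], [3.0, 4.0]) + 0.0 * nan))  # True
```

In this sketch, because the beta == 0.0 path never reads C, NaN or Inf values in C do not propagate into D, whereas evaluating 0.0 * NaN literally would yield NaN. If the kernel routes as described, the same holds for it; this matches the reference BLAS convention that C need not be set on input when beta is zero.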
Epilogue ordering matches SM100: D = lambda(Conv(A,B)) + beta * C.
That is, elementwise_compute_lambda_fn (pre-residual: bias / ReLU /
SiLU / GELU) fires on the post-cast c_type MMA output before the
residual FMA, and elementwise_lambda_fn (post-residual: a void
store-site lambda) fires afterward on the fused value.
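The ordering for a single output element can be modeled as follows. This is a simplified Python sketch. epilogue, relu, and the value-only store hook are hypothetical; the real elementwise_compute_lambda_fn and elementwise_lambda_fn have their own signatures in the Mojo API, and the c_type cast is assumed to have already happened.

```python
def relu(x):
    # Example pre-residual activation.
    return max(x, 0.0)

def epilogue(acc, c, beta, compute_fn=None, store_fn=None):
    """Model of the documented ordering for one output element.

    acc is assumed to already be the post-cast c_type MMA output.
    """
    y = acc
    if compute_fn is not None:
        y = compute_fn(y)      # pre-residual hook (bias / ReLU / SiLU / GELU)
    y = y + beta * c           # residual FMA
    if store_fn is not None:
        store_fn(y)            # post-residual store-site hook sees the fused value
    return y

stored = []
print(epilogue(-2.0, 1.0, 1.0, compute_fn=relu, store_fn=stored.append))  # 1.0
print(relu(-2.0 + 1.0 * 1.0))  # 0.0: ReLU after the residual gives a different answer
```

With acc = -2.0, C = 1.0, beta = 1.0 and ReLU as the pre-residual hook, the result is relu(-2.0) + 1.0 = 1.0. Applying ReLU after the residual would instead give relu(-1.0) = 0.0. The hook ordering must therefore match SM100 exactly for the backend swap to be transparent.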
Functions
- amd_4wave_conv_fprop_with_residual: Launch AMD 4-wave Conv2D fprop with optional residual add.