Mojo module
dispatch
AMD MI355X (CDNA4, gfx950) conv2d dispatch to amd_4wave_conv.
Returns True when the input/filter/output shapes and the runtime stride, padding, and dilation are handled by the 4-wave conv kernel; returns False so the caller falls back to its MIOpen path. Acceptance rules:
- Hardware: MI355X (gfx950) only; the 4-wave kernel inherits the chiplet/L2 swizzle and MFMA shapes specific to CDNA4.
- Input dtype: float8_e4m3fn, bfloat16, or float16. Output dtype: bfloat16 for FP8, otherwise tracks input.
- All input / filter / output spatial shapes must be static (TileTensor static_shape[i] >= 0). Dynamic-shape convs fall through to MIOpen.
- num_groups == 1 (4-wave conv is single-group).
- dilation == (1, 1).
- stride ∈ {(1, 1), (2, 2)}, with stride[0] == stride[1].
- symmetric_padding ∈ {(0, 0), (1, 1), (2, 2)}, square pad only.
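The acceptance rules above can be modelled as a single predicate. The following is a minimal Python sketch of that logic, not the Mojo implementation: `ConvRequest`, `is_accepted`, and `output_dtype` are hypothetical names introduced for illustration, dtypes and the architecture are represented as plain strings, and a dynamic dimension is assumed to be encoded as `-1`.

```python
from dataclasses import dataclass
from typing import Tuple

ACCEPTED_INPUT_DTYPES = {"float8_e4m3fn", "bfloat16", "float16"}
ACCEPTED_STRIDES = {(1, 1), (2, 2)}            # square strides only
ACCEPTED_PADS = {(0, 0), (1, 1), (2, 2)}       # square symmetric pads only


@dataclass(frozen=True)
class ConvRequest:
    """Hypothetical container for the fields the acceptance rules inspect."""
    arch: str                          # e.g. "gfx950"
    input_dtype: str
    static_shapes: Tuple[int, ...]     # assumption: -1 marks a dynamic dim
    num_groups: int
    stride: Tuple[int, int]
    dilation: Tuple[int, int]
    symmetric_padding: Tuple[int, int]


def output_dtype(input_dtype: str) -> str:
    """FP8 inputs produce bfloat16 outputs; other dtypes track the input."""
    return "bfloat16" if input_dtype == "float8_e4m3fn" else input_dtype


def is_accepted(req: ConvRequest) -> bool:
    """True if the 4-wave kernel would take the conv; False means fall back."""
    return (
        req.arch == "gfx950"
        and req.input_dtype in ACCEPTED_INPUT_DTYPES
        and all(dim >= 0 for dim in req.static_shapes)
        and req.num_groups == 1
        and req.dilation == (1, 1)
        and req.stride in ACCEPTED_STRIDES
        and req.symmetric_padding in ACCEPTED_PADS
    )
```

Because every rule is a conjunction, any single violation (for example a non-square stride such as `(2, 1)`) is enough to route the call to the fallback path.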
When accepted, the dispatcher:
- Allocates a K-padded FRSC filter buffer (zero-filled trailing K columns when R*S*C_in isn't a multiple of 2*BK = 256).
- Transposes the caller's filter (FCRS or RSCF) into FRSC.
- Comptime-materializes (stride, pad) and calls amd_4wave_conv with the appropriate kernel template parameters.
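To make the padding arithmetic and the layout change concrete, here is a small pure-Python sketch. It is an illustration under assumptions, not the dispatcher's code: the helper names are hypothetical, filters are nested lists rather than device buffers, and the padded buffer is assumed to be laid out as F rows of padded-K columns with K = R*S*C_in flattened in (r, s, c) order.

```python
K_ALIGN = 256  # 2 * BK from the text, which implies BK = 128


def padded_k(r: int, s: int, c_in: int, align: int = K_ALIGN) -> int:
    """Round K = R*S*C_in up to the next multiple of `align`."""
    k = r * s * c_in
    return -(-k // align) * align  # ceiling division, then scale back up


def fcrs_to_padded_frsc(filt, align: int = K_ALIGN):
    """filt[f][c][r][s] -> F rows in (r, s, c) order, zero-padded along K."""
    F, C, R, S = len(filt), len(filt[0]), len(filt[0][0]), len(filt[0][0][0])
    k_pad = padded_k(R, S, C, align)
    rows = []
    for f in range(F):
        row = [filt[f][c][r][s]
               for r in range(R) for s in range(S) for c in range(C)]
        row += [0] * (k_pad - len(row))  # zero-filled trailing K columns
        rows.append(row)
    return rows


def rscf_to_padded_frsc(filt, align: int = K_ALIGN):
    """filt[r][s][c][f] -> F rows in (r, s, c) order, zero-padded along K."""
    R, S, C, F = len(filt), len(filt[0]), len(filt[0][0]), len(filt[0][0][0])
    k_pad = padded_k(R, S, C, align)
    rows = []
    for f in range(F):
        row = [filt[r][s][c][f]
               for r in range(R) for s in range(S) for c in range(C)]
        row += [0] * (k_pad - len(row))
        rows.append(row)
    return rows
```

For a 3x3 filter with C_in = 64, K = 576 is not a multiple of 256, so `padded_k(3, 3, 64)` rounds up to 768 and each filter row carries 192 trailing zeros. With C_in = 256 and a 1x1 filter, K = 256 already satisfies the alignment and no padding is added.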
Mirrors the structure of nn.conv.gpu.amd.rdna.dispatch and
nn.conv.gpu.nvidia.sm100.dispatch.
Functions
- dispatch_amd_4wave_conv2d: Try to dispatch a Conv2D to amd_4wave_conv on MI355X.