Mojo module

dispatch

AMD MI355X (CDNA4, gfx950) conv2d dispatch to amd_4wave_conv.

Returns True when the input, filter, and output shapes and the runtime stride, padding, and dilation are handled by the 4-wave conv kernel; returns False so that the caller falls back to its MIOpen path. Acceptance rules:

  • Hardware: MI355X (gfx950) only, because the 4-wave kernel inherits the chiplet/L2 swizzle and MFMA shapes specific to CDNA4.
  • Input dtype: float8_e4m3fn, bfloat16, or float16. Output dtype: bfloat16 for FP8, otherwise tracks input.
  • All input / filter / output spatial shapes must be static (TileTensor static_shape[i] >= 0). Convolutions with dynamic shapes fall through to MIOpen.
  • num_groups == 1 (4-wave conv is single-group).
  • dilation == (1, 1).
  • stride ∈ {(1, 1), (2, 2)}, with stride[0] == stride[1].
  • symmetric_padding ∈ {(0, 0), (1, 1), (2, 2)}, square pad only.
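
The rules above amount to one boolean predicate. The Python sketch below models them for illustration only and is not the Mojo implementation. `ConvProblem`, `accepts`, and every field name are hypothetical, and treating a negative dimension as "dynamic" is an assumption inferred from the `static_shape[i] >= 0` check.

```python
from dataclasses import dataclass
from typing import Tuple

FP8 = "float8_e4m3fn"
SUPPORTED_INPUT_DTYPES = {FP8, "bfloat16", "float16"}
SUPPORTED_STRIDES = {(1, 1), (2, 2)}
SUPPORTED_PADS = {(0, 0), (1, 1), (2, 2)}


@dataclass(frozen=True)
class ConvProblem:
    """Hypothetical description of one conv2d call (not a real Mojo type)."""
    arch: str                      # GPU architecture string, e.g. "gfx950"
    in_dtype: str
    out_dtype: str
    spatial_dims: Tuple[int, ...]  # input/filter/output spatial dims; < 0 = dynamic (assumed)
    num_groups: int = 1
    stride: Tuple[int, int] = (1, 1)
    dilation: Tuple[int, int] = (1, 1)
    padding: Tuple[int, int] = (0, 0)


def expected_out_dtype(in_dtype: str) -> str:
    """FP8 inputs produce bfloat16 outputs; other dtypes track the input."""
    return "bfloat16" if in_dtype == FP8 else in_dtype


def accepts(p: ConvProblem) -> bool:
    """True if the 4-wave path would take the problem, False for the MIOpen fallback."""
    return (
        p.arch == "gfx950"
        and p.in_dtype in SUPPORTED_INPUT_DTYPES
        and p.out_dtype == expected_out_dtype(p.in_dtype)
        and all(d >= 0 for d in p.spatial_dims)
        and p.num_groups == 1
        and p.dilation == (1, 1)
        and p.stride in SUPPORTED_STRIDES
        and p.padding in SUPPORTED_PADS
    )
```

Under this model, a 3x3 FP8 convolution with stride (1, 1) and padding (1, 1) on gfx950 is accepted, while the same problem with stride (2, 1), a dynamic spatial dimension, or a different architecture is rejected and would take the fallback path.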

When accepted, the dispatcher:

  1. Allocates a K-padded FRSC filter buffer (zero-filled trailing K columns when R*S*C_in isn't a multiple of 2*BK = 256).
  2. Transposes the caller's filter (FCRS or RSCF) into FRSC.
  3. Comptime-materializes (stride, pad) and calls amd_4wave_conv with the appropriate kernel template parameters.
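
Steps 1 and 2 can be illustrated with a small Python sketch of the K-padding arithmetic and the FCRS to FRSC reordering. It is a model under stated assumptions, not the dispatcher's code: `padded_k` and `fcrs_to_padded_frsc` are hypothetical names, BK = 128 is inferred from "2*BK = 256", and the sketch assumes each filter's FRSC data is flattened into one row of length K = R*S*C_in with C_in varying fastest.

```python
def padded_k(r: int, s: int, c_in: int, bk: int = 128) -> int:
    """Round K = R*S*C_in up to the next multiple of 2*BK (assumed BK = 128)."""
    k = r * s * c_in
    multiple = 2 * bk
    return -(-k // multiple) * multiple  # ceiling division, then scale back up


def fcrs_to_padded_frsc(filt, bk: int = 128):
    """Reorder a nested-list filter indexed filt[f][c][r][s] into F rows in
    (r, s, c) order, zero-filling the trailing K columns of each row."""
    c_in = len(filt[0])
    r_dim = len(filt[0][0])
    s_dim = len(filt[0][0][0])
    k = r_dim * s_dim * c_in
    k_pad = padded_k(r_dim, s_dim, c_in, bk)
    rows = []
    for per_filter in filt:
        # C is the innermost (fastest-varying) index in FRSC.
        row = [
            per_filter[c][r][s]
            for r in range(r_dim)
            for s in range(s_dim)
            for c in range(c_in)
        ]
        row.extend([0.0] * (k_pad - k))  # zero-filled trailing K columns
        rows.append(row)
    return rows
```

For example, a 3x3 filter with C_in = 64 has K = 576, which this sketch pads to 768 (192 zero columns), while a 1x1 filter with C_in = 256 already has K = 256 and needs no padding.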

Mirrors the structure of nn.conv.gpu.amd.rdna.dispatch and nn.conv.gpu.nvidia.sm100.dispatch.

Functions