Mojo module

dispatch

AMD MI355X (CDNA4, gfx950) conv2d dispatch to amd_4wave_conv.

Returns True when the input, filter, and output shapes, together with the runtime stride, padding, and dilation, are supported by the 4-wave conv kernel. Returns False so the caller falls back to its MIOpen path. Acceptance rules:

  • Hardware: MI355X (gfx950) only; the 4-wave kernel inherits the chiplet/L2 swizzle and MFMA shapes specific to CDNA4.
  • Input dtype: float8_e4m3fn, bfloat16, or float16. Output dtype: bfloat16 for FP8 inputs; otherwise it matches the input dtype.
  • All input / filter / output spatial shapes must be static (TileTensor static_shape[i] >= 0). Convolutions with any dynamic spatial dimension fall through to MIOpen.
  • num_groups == 1 (4-wave conv is single-group).
  • dilation == (1, 1).
  • stride ∈ {(1, 1), (2, 2)}, with stride[0] == stride[1].
  • symmetric_padding ∈ {(0, 0), (1, 1), (2, 2)}, square pad only.
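The acceptance rules above reduce to a single predicate. The Python sketch below models only that logic. The `ConvRequest` container, its field names, the string dtype and architecture tags, and the use of -1 to mark a dynamic dimension are all hypothetical and are not the Mojo module's API.

```python
from dataclasses import dataclass

# Sketch convention (hypothetical): a negative dimension means "dynamic",
# mirroring the TileTensor rule that static_shape[i] >= 0 marks a static dim.
DYNAMIC = -1

SUPPORTED_INPUT_DTYPES = {"float8_e4m3fn", "bfloat16", "float16"}
SUPPORTED_STRIDES = {(1, 1), (2, 2)}
SUPPORTED_PADDINGS = {(0, 0), (1, 1), (2, 2)}


@dataclass(frozen=True)
class ConvRequest:
    """Hypothetical bundle of the properties the dispatcher inspects."""
    arch: str               # e.g. "gfx950" for MI355X
    input_dtype: str
    output_dtype: str
    input_spatial: tuple    # (H, W) of the input
    filter_spatial: tuple   # (R, S) of the filter
    output_spatial: tuple   # (H_out, W_out) of the output
    num_groups: int
    stride: tuple
    dilation: tuple
    padding: tuple          # symmetric padding (pad_h, pad_w)


def expected_output_dtype(input_dtype: str) -> str:
    # FP8 inputs produce bfloat16 output; other inputs keep their dtype.
    return "bfloat16" if input_dtype == "float8_e4m3fn" else input_dtype


def accepts_4wave_conv(req: ConvRequest) -> bool:
    """True if the 4-wave kernel applies; False means fall back to MIOpen."""
    if req.arch != "gfx950":
        return False
    if req.input_dtype not in SUPPORTED_INPUT_DTYPES:
        return False
    if req.output_dtype != expected_output_dtype(req.input_dtype):
        return False
    spatial = (req.input_spatial, req.filter_spatial, req.output_spatial)
    if any(dim < 0 for shape in spatial for dim in shape):
        return False
    if req.num_groups != 1:
        return False
    if tuple(req.dilation) != (1, 1):
        return False
    # Both sets contain only square entries, so membership also enforces
    # stride[0] == stride[1] and pad[0] == pad[1].
    if tuple(req.stride) not in SUPPORTED_STRIDES:
        return False
    return tuple(req.padding) in SUPPORTED_PADDINGS
```

Checking the rules in a fixed order keeps the predicate cheap. The first failing rule short-circuits to the MIOpen fallback.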

When accepted, the dispatcher:

  1. Allocates a K-padded FRSC filter buffer (zero-filled trailing K columns when R*S*C_in isn't a multiple of 2*BK = 256).
  2. Transposes the caller's filter (FCRS or RSCF) into FRSC.
  3. Comptime-materializes (stride, pad) and calls amd_4wave_conv with the appropriate kernel template parameters.
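Steps 1 and 2 amount to index arithmetic on the flattened filter. The Python sketch below models that arithmetic under several assumptions: BK = 128 (inferred from 2*BK = 256), flat row-major buffers, and the hypothetical helper names `padded_k` and `to_frsc_padded`. It covers only the layout math and does not model the GPU kernels. Each FRSC row holds the R*S*C_in values of one output filter, followed by zeros up to the padded K length.

```python
BK = 128            # assumption: inferred from "2*BK = 256" in the text
K_ALIGN = 2 * BK    # reduction length must be a multiple of 256


def padded_k(r: int, s: int, c_in: int) -> int:
    """Round the GEMM reduction length R*S*C_in up to a multiple of 2*BK."""
    k = r * s * c_in
    return -(-k // K_ALIGN) * K_ALIGN   # ceiling division, then scale back


def to_frsc_padded(filt, f, c, r, s, src_layout):
    """Copy a flat FCRS or RSCF filter into a zero-padded, flat FRSC buffer.

    The result holds f rows of padded_k(r, s, c) values: the first r*s*c
    entries of row fi are filter fi in (R, S, C) order, the rest stay zero.
    """
    if src_layout not in ("FCRS", "RSCF"):
        raise ValueError(f"unsupported filter layout: {src_layout!r}")
    kp = padded_k(r, s, c)
    out = [0.0] * (f * kp)    # zero-filled, so trailing K columns stay zero
    for fi in range(f):
        for ri in range(r):
            for si in range(s):
                for ci in range(c):
                    if src_layout == "FCRS":
                        src = ((fi * c + ci) * r + ri) * s + si
                    else:  # "RSCF"
                        src = ((ri * s + si) * c + ci) * f + fi
                    out[fi * kp + (ri * s + si) * c + ci] = filt[src]
    return out
```

For example, a 3x3 filter with C_in = 64 has R*S*C_in = 576, so each FRSC row is padded to 768 elements. The (stride, pad) specialization in step 3 is a compile-time choice in the Mojo code and has no counterpart in this sketch.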

Mirrors the structure of nn.conv.gpu.amd.rdna.dispatch and nn.conv.gpu.nvidia.sm100.dispatch.

Functions