
Mojo module

dispatch_3d

AMD MI355X (CDNA4, gfx950) conv3d dispatch to amd_4wave_conv.

Sibling of nn.conv.gpu.amd.dispatch.dispatch_amd_4wave_conv2d, extended to 3D NDHWC inputs via the loader's Q > 1 mode. It runs as a single implicit-GEMM kernel: there is no M*K im2col scratch buffer, accumulation happens in fp32 inside the MMA, and there are no per-q round-trips over Q slices.
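The Python sketch below makes "single-kernel implicit GEMM" concrete. It is not the Mojo kernel. It shows how a GEMM row m and column k can be mapped back to an NDHWC input element on demand, so no M*K im2col matrix is ever stored. The following are assumptions for illustration: the K ordering (q, r, s, c), the x[n][d][h][w][c] and filt[f][q][r][s][c] nesting, and the helper names `out_size`, `a_element`, and `conv3d_implicit_gemm`. Dilation is fixed at 1 and a single stride/pad applies to every axis, consistent with the acceptance rules below.

```python
# Hypothetical sketch: implicit-GEMM indexing for an NDHWC conv3d (dilation 1).
# Assumed layouts: x[n][d][h][w][c], filt[f][q][r][s][c], K ordered as (q, r, s, c).


def out_size(n_in, k, stride, pad):
    """Output extent along one spatial axis with dilation fixed at 1."""
    return (n_in + 2 * pad - k) // stride + 1


def a_element(x, m, k, dims, ksize, stride, pad):
    """Gather A[m][k] directly from the input; no M*K im2col buffer is built."""
    D, H, W, C = dims
    Q, R, S = ksize
    Do = out_size(D, Q, stride, pad)
    Ho = out_size(H, R, stride, pad)
    Wo = out_size(W, S, stride, pad)
    # Row m -> (batch n, output depth od, output row oh, output col ow).
    n, rem = divmod(m, Do * Ho * Wo)
    od, rem = divmod(rem, Ho * Wo)
    oh, ow = divmod(rem, Wo)
    # Column k -> (filter depth q, filter row r, filter col s, channel c).
    q, rem = divmod(k, R * S * C)
    r, rem = divmod(rem, S * C)
    s, c = divmod(rem, C)
    d = od * stride - pad + q
    h = oh * stride - pad + r
    w = ow * stride - pad + s
    if 0 <= d < D and 0 <= h < H and 0 <= w < W:
        return x[n][d][h][w][c]
    return 0.0  # the element falls in the zero padding


def conv3d_implicit_gemm(x, filt, ksize, stride, pad):
    """Return out[m][f] = sum_k A[m][k] * B[f][k] as an M x F list of lists."""
    N, D, H, W = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    C = len(x[0][0][0][0])
    Q, R, S = ksize
    F = len(filt)
    M = N * out_size(D, Q, stride, pad) * out_size(H, R, stride, pad) * out_size(W, S, stride, pad)
    K = Q * R * S * C
    # Flatten each filter into one row of B, using the same (q, r, s, c) order as k.
    B = [
        [filt[f][q][r][s][c]
         for q in range(Q) for r in range(R) for s in range(S) for c in range(C)]
        for f in range(F)
    ]
    dims = (D, H, W, C)
    out = []
    for m in range(M):
        # Python floats are fp64 here; the real kernel accumulates in fp32 inside the MMA.
        row = [sum(a_element(x, m, k, dims, ksize, stride, pad) * B[f][k] for k in range(K))
               for f in range(F)]
        out.append(row)
    return out
```

The only state the sketch keeps is the packed filter B. Every A element is recomputed from (m, k), which is the property that lets a single kernel skip the im2col scratch buffer.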

The dispatcher returns True when the input, filter, and output shapes and the runtime stride, padding, and dilation are all supported by the 4-wave conv kernel. It returns False so that the caller falls back to its im2col / cuDNN path. Acceptance rules:

  • Hardware: MI355X (gfx950) only.
  • Input dtype: float8_e4m3fn, bfloat16, or float16. Output dtype: bfloat16 when the input is FP8 (float8_e4m3fn); otherwise the same as the input dtype.
  • All input/filter/output spatial shapes must be static (TileTensor static_shape[i] >= 0). Dynamic shapes fall through.
  • num_groups == 1.
  • dilation == (1, 1, 1).
  • stride[0] == stride[1] == stride[2] and stride ∈ {1, 2}.
  • symmetric_padding[i] ∈ {0, 1, 2} for each axis, and all three pads must be equal. The kernel takes a single (stride, pad) tuple, so pad_h must equal pad_w. pad_d could in principle differ, but for now all three pads must match so that the static-launch enumeration stays bounded.
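The acceptance rules above amount to a single predicate. The Python sketch below expresses them that way. The function and constant names (`accepts_conv3d`, `expected_output_dtype`, `SUPPORTED_INPUT_DTYPES`), the string encoding of dtypes and architecture, and the use of a negative value for a dynamic dimension (following the `static_shape[i] >= 0` check) are assumptions for illustration, not the Mojo signature.

```python
# Hypothetical mirror of the acceptance rules; names and the string dtype/arch
# encoding are illustrative, not the Mojo signature.

SUPPORTED_INPUT_DTYPES = {"float8_e4m3fn", "bfloat16", "float16"}


def expected_output_dtype(input_dtype):
    """FP8 inputs produce bfloat16; other inputs keep their own dtype."""
    return "bfloat16" if input_dtype == "float8_e4m3fn" else input_dtype


def accepts_conv3d(arch, in_dtype, out_dtype, static_dims, num_groups,
                   stride, dilation, pad):
    """Return True if the 4-wave kernel would take this conv3d, else False."""
    if arch != "gfx950":                       # MI355X only
        return False
    if in_dtype not in SUPPORTED_INPUT_DTYPES:
        return False
    if out_dtype != expected_output_dtype(in_dtype):
        return False
    if any(dim < 0 for dim in static_dims):    # negative = dynamic dimension
        return False
    if num_groups != 1:
        return False
    if tuple(dilation) != (1, 1, 1):
        return False
    if not (stride[0] == stride[1] == stride[2] and stride[0] in (1, 2)):
        return False
    # All three pads must match and lie in {0, 1, 2}.
    if not (pad[0] == pad[1] == pad[2] and pad[0] in (0, 1, 2)):
        return False
    return True
```

Every rejection is a plain early return, which matches the "fall through" contract: a False answer tells the caller to use another path, not that an error occurred.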

When accepted, the dispatcher:

  1. Allocates a K-padded [F, K_padded] filter buffer, where K_padded = round_up(QRSC_in, 2*BK) with 2*BK = 256, and zero-fills the trailing K_padded - QRSC_in columns.
  2. Transposes the caller's filter (QRSCF or FCQRS layout) into that [F, K_padded] buffer in row-major order.
  3. Materializes (stride, pad) as compile-time parameters and calls amd_4wave_conv with Q > 1 (3D mode).
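Steps 1 and 2 can be illustrated with a small Python sketch of the filter packing. It assumes BK = 128 (so 2*BK = 256, as in the K_padded formula) and a column order k = ((q*R + r)*S + s)*C + c, which is inferred from the QRSC naming and not confirmed by the source. `round_up` and `pack_filter` are hypothetical helpers, and the returned buffer is a flat row-major list indexed as `f * K_padded + k`.

```python
BK = 128          # assumption: 2 * BK = 256, per the K_padded formula
K_ALIGN = 2 * BK


def round_up(x, m):
    """Smallest multiple of m that is >= x."""
    return (x + m - 1) // m * m


def pack_filter(filt, layout, Q, R, S, C, F):
    """Pack a QRSCF or FCQRS filter into a flat row-major [F, K_padded] buffer."""
    K = Q * R * S * C
    k_padded = round_up(K, K_ALIGN)
    packed = [0.0] * (F * k_padded)  # trailing k_padded - K columns stay zero
    for f in range(F):
        for q in range(Q):
            for r in range(R):
                for s in range(S):
                    for c in range(C):
                        if layout == "QRSCF":
                            v = filt[q][r][s][c][f]
                        elif layout == "FCQRS":
                            v = filt[f][c][q][r][s]
                        else:
                            raise ValueError(f"unsupported layout {layout!r}")
                        k = ((q * R + r) * S + s) * C + c  # assumed K order
                        packed[f * k_padded + k] = v
    return packed, k_padded
```

For example, a 3x3x3 filter with C_in = 64 has QRSC_in = 1728, which rounds up to K_padded = 1792. The last 64 columns of each row are zero, so they contribute nothing to the dot products.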

Mirrors the structure of nn.conv.gpu.amd.dispatch.

Functions

  • dispatch_amd_4wave_conv3d: Try to dispatch a Conv3D to amd_4wave_conv on MI355X. Returns True if handled; False if the caller should fall through (typically to dispatch_im2col_matmul_conv3d).
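The True/False contract implies a simple caller-side pattern. The Python sketch below shows it. `run_conv3d` and both callables are hypothetical stand-ins for dispatch_amd_4wave_conv3d and dispatch_im2col_matmul_conv3d, not their real signatures.

```python
from typing import Callable


def run_conv3d(try_4wave: Callable[[], bool], im2col_matmul: Callable[[], None]) -> str:
    """Hypothetical caller: try the 4-wave path, fall through on False."""
    if try_4wave():
        # True: the conv was handled by amd_4wave_conv; nothing else to run.
        return "amd_4wave_conv"
    # False: the caller still owes the result, typically via im2col + matmul.
    im2col_matmul()
    return "im2col_matmul"
```

Trying the specialized path first and falling back on False keeps the fast path optional: unsupported shapes still take the general path.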