Mojo module
dispatch_3d
AMD MI355X (CDNA4, gfx950) conv3d dispatch to amd_4wave_conv.
Sibling of nn.conv.gpu.amd.dispatch.dispatch_amd_4wave_conv2d, extended
to 3D NDHWC inputs via the loader's Q > 1 mode. Single-kernel
implicit-GEMM: no M*K im2col scratch, fp32 accumulator inside the
MMA, no per-q Q-slice round-trips.
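The sketch below makes the implicit-GEMM description concrete. It computes the GEMM problem size for an NDHWC conv3d and the size of the explicit M*K im2col buffer that a two-step path would have to materialize. It uses the standard convolution-to-GEMM mapping: one GEMM row per output voxel, one column per output channel, and one reduction step per filter tap and input channel. It assumes Q, R, S are the filter depth, height, and width, matching the QRSCF filter layout this module accepts. The helper names are hypothetical and are not part of the module.

```python
# Hypothetical helpers (NOT part of the Mojo module): map an NDHWC conv3d
# onto GEMM dimensions and size the im2col buffer the single-kernel
# implicit-GEMM path avoids. Dilation is fixed to 1, as the dispatcher requires.

def conv_out_dim(size: int, k: int, stride: int, pad: int, dilation: int = 1) -> int:
    """Standard output-size formula for one spatial axis."""
    return (size + 2 * pad - dilation * (k - 1) - 1) // stride + 1


def implicit_gemm_dims(n, d, h, w, c_in, q, r, s, f, stride, pad):
    """Return (M, N, K) of the GEMM equivalent to an NDHWC conv3d.

    M = N * D_out * H_out * W_out  (one row per output voxel)
    N = F                          (one column per output channel)
    K = Q * R * S * C_in           (one reduction step per filter tap and channel)
    """
    d_out = conv_out_dim(d, q, stride, pad)
    h_out = conv_out_dim(h, r, stride, pad)
    w_out = conv_out_dim(w, s, stride, pad)
    m = n * d_out * h_out * w_out
    k = q * r * s * c_in
    return m, f, k


m, n_gemm, k = implicit_gemm_dims(n=1, d=16, h=56, w=56, c_in=64,
                                  q=3, r=3, s=3, f=128, stride=1, pad=1)
print(m, n_gemm, k)  # 50176 128 1728
print(m * k)         # 86704128 elements for an explicit im2col buffer
```

For this 3x3x3 example, an explicit im2col buffer would hold roughly 86.7 million elements for a single batch item. The single-kernel path never allocates this buffer.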
Returns True when the input/filter/output shapes and the runtime stride, padding, and dilation are handled by the 4-wave conv kernel. Returns False so the caller can fall back to its im2col / cuDNN path. Acceptance rules:
- Hardware: MI355X (gfx950) only.
- Input dtype: float8_e4m3fn, bfloat16, or float16. Output dtype: bfloat16 for FP8, otherwise tracks input.
- All input/filter/output spatial shapes must be static (TileTensor static_shape[i] >= 0). Dynamic shapes fall through.
- num_groups == 1.
- dilation == (1, 1, 1).
- stride[0] == stride[1] == stride[2], and stride ∈ {1, 2}.
- symmetric_padding[i] ∈ {0, 1, 2} for each axis. The kernel takes a single (stride, pad) tuple, so pad_h must equal pad_w. pad_d could in principle differ, but for now all three pads must be equal so the static-launch enumeration stays bounded.
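The acceptance rules above can be summarized as a single predicate. This is a minimal sketch, not the dispatcher's code. The function name, the argument names, the string dtype and architecture tags, and the convention that a negative value marks a dynamic dimension are all assumptions for illustration. Treating an output dtype that does not follow the stated mapping as a rejection is also an assumption.

```python
# Hypothetical sketch of the conv3d acceptance rules. Names, string tags, and
# the "negative means dynamic" encoding are illustrative assumptions; the real
# check is done in Mojo on TileTensor metadata.

SUPPORTED_INPUT_DTYPES = {"float8_e4m3fn", "bfloat16", "float16"}


def expected_output_dtype(in_dtype: str) -> str:
    # FP8 inputs produce bfloat16 output; other dtypes pass through unchanged.
    return "bfloat16" if in_dtype == "float8_e4m3fn" else in_dtype


def accepts_conv3d(arch, in_dtype, out_dtype, static_shapes,
                   num_groups, dilation, stride, pad):
    """Return True if the 4-wave conv3d path would take this problem.

    static_shapes: all static_shape[i] entries; negative marks a dynamic dim.
    dilation, stride, pad: (depth, height, width) tuples.
    """
    if arch != "gfx950":                       # MI355X only
        return False
    if in_dtype not in SUPPORTED_INPUT_DTYPES:
        return False
    if out_dtype != expected_output_dtype(in_dtype):  # assumed rejection rule
        return False
    if any(dim < 0 for dim in static_shapes):  # dynamic shapes fall through
        return False
    if num_groups != 1 or tuple(dilation) != (1, 1, 1):
        return False
    if not (stride[0] == stride[1] == stride[2] and stride[0] in (1, 2)):
        return False
    if any(p not in (0, 1, 2) for p in pad):
        return False
    # Current restriction: pad_d == pad_h == pad_w.
    return pad[0] == pad[1] == pad[2]
```

Under these rules, only stride ∈ {1, 2} and a uniform pad ∈ {0, 1, 2} pass. If there is one kernel instantiation per pair, the dispatcher has at most 2 × 3 = 6 (stride, pad) combinations to materialize at compile time for each dtype.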
When accepted, the dispatcher:
- Allocates a K-padded [F, K_padded] filter buffer, where K_padded = round_up(Q*R*S*C_in, 2*BK) with 2*BK = 256, zero-filling the trailing K columns.
- Transposes the caller's filter (QRSCF or FCQRS) into [F, K_padded] row-major.
- Comptime-materializes (stride, pad) and calls amd_4wave_conv with Q > 1 (3D mode).
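The filter-preparation steps can be sketched in plain Python on flat row-major lists. This sketch rests on several assumptions. BK is taken as 128, so that 2*BK = 256 as stated above. The K dimension is assumed to be ordered Q, R, S, C_in, as the QRSC_in product suggests. round_up, pack_filter_qrscf, and pack_filter_fcqrs are hypothetical names and do not reflect the module's actual implementation.

```python
# Hypothetical host-side sketch of the filter packing step. Assumes BK = 128
# (so 2 * BK = 256), K ordered as (Q, R, S, C_in), and flat row-major storage.

BK = 128


def round_up(x: int, multiple: int) -> int:
    return (x + multiple - 1) // multiple * multiple


def pack_filter_qrscf(filt, q, r, s, c_in, f):
    """Transpose a flat QRSCF filter into a zero-padded [F, K_padded] buffer.

    Column k of row fi is tap ((qi*R + ri)*S + si)*C_in + ci; columns
    k >= Q*R*S*C_in stay zero.
    """
    k = q * r * s * c_in
    k_padded = round_up(k, 2 * BK)
    out = [0.0] * (f * k_padded)   # zero-filled, including the K tail
    for kk in range(k):            # kk enumerates (q, r, s, c) in QRSC order
        for fi in range(f):
            # In QRSCF, F is innermost, so element (kk, fi) sits at kk*F + fi.
            out[fi * k_padded + kk] = filt[kk * f + fi]
    return out, k_padded


def pack_filter_fcqrs(filt, q, r, s, c_in, f):
    """Same packing for a flat FCQRS filter (F outermost, S innermost)."""
    k = q * r * s * c_in
    k_padded = round_up(k, 2 * BK)
    out = [0.0] * (f * k_padded)
    for fi in range(f):
        for ci in range(c_in):
            for qi in range(q):
                for ri in range(r):
                    for si in range(s):
                        src = (((fi * c_in + ci) * q + qi) * r + ri) * s + si
                        dst_k = ((qi * r + ri) * s + si) * c_in + ci
                        out[fi * k_padded + dst_k] = filt[src]
    return out, k_padded
```

Consider a 3x3x3 filter with 64 input channels. Here K = 1728 and round_up(1728, 256) = 1792, so each of the F rows ends with 64 zero columns. Both layouts produce the same packed buffer, so the kernel sees a single [F, K_padded] format regardless of which layout the caller used.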
Mirrors the structure of nn.conv.gpu.amd.dispatch.
Functions
- dispatch_amd_4wave_conv3d: Try to dispatch a Conv3D to amd_4wave_conv on MI355X. Returns True if handled; False if the caller should fall through (typically to dispatch_im2col_matmul_conv3d).