IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo module

qslice_conv3d

Q-slice 3-D conv → Q sequential SM100 2-D conv calls with fp32 accumulator.

A 3D conv with stride=1, dilation=1 decomposes exactly as:

output[n, d_o, h_o, w_o, f]
  = Σ_q  [ Σ_{r,s,c}  input[n, d_o+q, h_o+r-pad_h, w_o+s-pad_w, c]
                      * filter[q, r, s, c, f] ]

For each fixed q, the inner Σ_{r,s,c} is a 2-D conv of depth slice d_o+q with the filter slab filter[q], so invoking the existing dispatch_sm100_conv2d Q times inherits SM100 UMMA performance without forking any TMA code.
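
The decomposition can be checked numerically with a small pure-Python sketch (Python is used only for illustration; this is not the module's Mojo code). It drops the batch dimension N, uses integer data so the comparison is exact, and assumes that the q-th 2-D conv processes the depth slices input[q], ..., input[q + D_out - 1] as a batch of independent images. The names conv2d, conv3d_direct, and conv3d_qslice are hypothetical.

```python
import random

def conv2d(x, w, pad_h, pad_w):
    """Stride-1, dilation-1 2-D conv: x is [H][W][C], w is [R][S][C][F] -> [H_o][W_o][F]."""
    H, W, C = len(x), len(x[0]), len(x[0][0])
    R, S, F = len(w), len(w[0]), len(w[0][0][0])
    H_o, W_o = H + 2 * pad_h - R + 1, W + 2 * pad_w - S + 1
    out = [[[0] * F for _ in range(W_o)] for _ in range(H_o)]
    for h in range(H_o):
        for v in range(W_o):
            for r in range(R):
                for s in range(S):
                    hi, wi = h + r - pad_h, v + s - pad_w
                    if not (0 <= hi < H and 0 <= wi < W):
                        continue  # tap falls in the zero padding
                    for c in range(C):
                        for f in range(F):
                            out[h][v][f] += x[hi][wi][c] * w[r][s][c][f]
    return out

def conv3d_direct(x, w, pad_h, pad_w):
    """Reference 3-D conv: x is [D][H][W][C], w is [Q][R][S][C][F]; stride 1, no depth padding."""
    D, H, W, C = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    Q, R, S, F = len(w), len(w[0]), len(w[0][0]), len(w[0][0][0][0])
    D_o, H_o, W_o = D - Q + 1, H + 2 * pad_h - R + 1, W + 2 * pad_w - S + 1
    out = [[[[0] * F for _ in range(W_o)] for _ in range(H_o)] for _ in range(D_o)]
    for d in range(D_o):
        for h in range(H_o):
            for v in range(W_o):
                for q in range(Q):
                    for r in range(R):
                        for s in range(S):
                            hi, wi = h + r - pad_h, v + s - pad_w
                            if 0 <= hi < H and 0 <= wi < W:
                                for c in range(C):
                                    for f in range(F):
                                        out[d][h][v][f] += x[d + q][hi][wi][c] * w[q][r][s][c][f]
    return out

def conv3d_qslice(x, w, pad_h, pad_w):
    """Q-slice decomposition: sum over q of 2-D convs on depth-shifted slices."""
    D_o = len(x) - len(w) + 1
    out = None
    for q in range(len(w)):  # one 2-D conv "call" per filter slab w[q]
        # Depth slices x[q] .. x[q + D_o - 1] form a batch of D_o 2-D images.
        part = [conv2d(x[d + q], w[q], pad_h, pad_w) for d in range(D_o)]
        if out is None:
            out = part
        else:  # elementwise add over [D_o][H_o][W_o][F]
            out = [[[[a + b for a, b in zip(ra, rb)] for ra, rb in zip(pa, pb)]
                    for pa, pb in zip(da, db)] for da, db in zip(out, part)]
    return out

random.seed(0)
D, H, W, C, Q, R, S, F = 4, 5, 5, 2, 3, 3, 3, 2
x = [[[[random.randint(-3, 3) for _ in range(C)] for _ in range(W)] for _ in range(H)] for _ in range(D)]
w = [[[[[random.randint(-3, 3) for _ in range(F)] for _ in range(C)] for _ in range(S)] for _ in range(R)] for _ in range(Q)]
ref = conv3d_direct(x, w, 1, 1)
dec = conv3d_qslice(x, w, 1, 1)
ok = ref == dec
print(ok)  # True
```

Because every term of the 3-D sum appears exactly once in the Q partial 2-D convs, the two results agree exactly on integer data; with floating-point data only the summation order differs.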

ACCUMULATION STRATEGY

Each SM100 2-D conv produces a bf16 result (the kernel accumulates in fp32 internally but stores bf16). Chaining the Q convs directly via has_residual=True / beta=1.0 round-trips the running sum through bf16 at every step, so per-element rounding error compounds far beyond what a single 3-D conv produces (empirically a diff of ~100-1000 versus the reference on WAN mid_res shapes).
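
A toy scalar example shows why chaining through bf16 drifts. The Python sketch below emulates bf16 by keeping the top 16 bits of an fp32 value with round-to-nearest-even; the partial sums are made up for illustration and are unrelated to the WAN measurements. Near 256 the bf16 spacing is 2.0, so every +0.75 step is rounded away when the running sum is stored as bf16, while the fp32 path keeps it.

```python
import struct

def to_bf16(x):
    """Round to the nearest bf16 value (ties to even). x is first rounded to fp32; NaN/inf not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def to_fp32(x):
    """Round a Python float (fp64) to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

partials = [256.0] + [0.75] * 8  # toy per-q partial sums for one output element

# has_residual=True chaining: each step adds in fp32, then stores the running sum as bf16.
chained = 0.0
for p in partials:
    chained = to_bf16(chained + p)

# Single all-at-once 3-D conv: accumulate in fp32 and round to bf16 once.
fp32_sum = 0.0
for p in partials:
    fp32_sum = to_fp32(fp32_sum + p)
single = to_bf16(fp32_sum)

print(chained, single)  # 256.0 262.0
```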

This dispatcher instead maintains a dedicated fp32 accumulator buffer outside the conv calls:

  1. Allocate accum_fp32 of output size and zero-fill it.
  2. Allocate one reusable temp_bf16 buffer of output size.
  3. For q in [0, Q):
    • Call dispatch_sm100_conv2d(has_residual=False) to write conv(input[q], filter[q]) into temp_bf16.
    • Launch an elementwise kernel that reads temp_bf16, casts to fp32, and adds into accum_fp32.
  4. Launch a final kernel that casts accum_fp32 → bf16 and writes to the user's output (fused with the caller's 5-D epilogue when one is provided).
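
Steps 1-4 can be sketched on a flattened output buffer in Python. This is an illustration under stated assumptions, not the dispatcher's code: bf16 is emulated by keeping the top 16 bits of fp32 with round-to-nearest-even, sequential Python loops stand in for the GPU kernels, the caller's epilogue is omitted, and partials[q] is a hypothetical stand-in for the fp32 values the q-th dispatch_sm100_conv2d call computes internally before storing bf16.

```python
import struct

def to_bf16(x):
    """Round to the nearest bf16 value (ties to even). x is first rounded to fp32; NaN/inf not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def to_fp32(x):
    """Round a Python float (fp64) to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def qslice_accumulate(partials):
    """Hypothetical helper mirroring steps 1-4 on one flattened output buffer."""
    size = len(partials[0])
    accum_fp32 = [0.0] * size                # 1. zero-filled fp32 accumulator
    temp_bf16 = [0.0] * size                 # 2. one temp buffer, reused for every q
    for partial in partials:                 # 3. Q sequential 2-D convs
        for i in range(size):                #    conv stores its result as bf16 (has_residual=False)
            temp_bf16[i] = to_bf16(partial[i])
        for i in range(size):                #    elementwise kernel: cast to fp32, add into accum
            accum_fp32[i] = to_fp32(accum_fp32[i] + temp_bf16[i])
    return [to_bf16(a) for a in accum_fp32]  # 4. single final cast to bf16

partials = [[256.0, 1.0]] + [[0.75, 1.0]] * 8  # Q = 9 toy partial results
print(qslice_accumulate(partials))  # [262.0, 9.0]
```

Reusing one temp_bf16 buffer keeps the extra memory at one fp32 and one bf16 output-sized buffer regardless of Q.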

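Apart from the single bf16 store of each per-q partial result (step 3), only the final cast is lossy: the running sum is never round-tripped through bf16, so accumulation across q stays in fp32, closely matching what a single all-at-once 3-D conv would produce internally.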

Gate:

  • bf16 input/filter/output dtype.
  • SM100 device (_is_sm10x_gpu).
  • filter_is_fcrs=False (QRSCF only: the per-q slab is a contiguous RSCF view at offset q*R*S*C*F; FCQRS would need a separate extraction kernel because a fixed-q FCQRS slice is non-contiguous).
  • stride=1, dilation=1, groups=1, Q>1.
  • Zero temporal padding (WAN's causal 3D convs pre-pad temporal externally).
  • C_in % 64 == 0 and C_out % 64 == 0 (SM100 alignment).
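
The gate can be read as a single predicate. The Python sketch below restates it; the Conv3dShape record, its field names, and the helper functions are hypothetical stand-ins (the real check lives in Mojo and calls _is_sm10x_gpu), and temporal padding is modelled as a (front, back) pair. It also checks the QRSCF claim: in a row-major QRSCF filter, filter[q] begins at offset q*R*S*C*F and its elements follow in RSCF order.

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class Conv3dShape:
    """Hypothetical bundle of the properties the gate inspects."""
    dtype: str            # shared dtype of input, filter, and output
    is_sm100: bool        # stands in for the _is_sm10x_gpu() check
    filter_is_fcrs: bool  # False means QRSCF filter layout
    stride: tuple
    dilation: tuple
    groups: int
    pad_d: tuple          # (front, back) temporal padding
    Q: int                # filter depth
    C_in: int
    C_out: int

def qslice_eligible(p: Conv3dShape) -> bool:
    """Mirror of the gate list; False means fall through to the im2col path."""
    return (
        p.dtype == "bf16"
        and p.is_sm100
        and not p.filter_is_fcrs
        and p.stride == (1, 1, 1)
        and p.dilation == (1, 1, 1)
        and p.groups == 1
        and p.Q > 1
        and p.pad_d == (0, 0)
        and p.C_in % 64 == 0
        and p.C_out % 64 == 0
    )

def qrscf_index(q, r, s, c, f, R, S, C, F):
    """Row-major flat index of filter[q, r, s, c, f] in a QRSCF tensor."""
    return (((q * R + r) * S + s) * C + c) * F + f

def rscf_index(r, s, c, f, S, C, F):
    """Row-major flat index of filter[r, s, c, f] in an RSCF tensor."""
    return ((r * S + s) * C + c) * F + f

R, S, C, F = 3, 3, 64, 128
shape = Conv3dShape(dtype="bf16", is_sm100=True, filter_is_fcrs=False,
                    stride=(1, 1, 1), dilation=(1, 1, 1), groups=1,
                    pad_d=(0, 0), Q=3, C_in=C, C_out=F)
print(qslice_eligible(shape))                    # True
print(qslice_eligible(replace(shape, C_in=96)))  # False: 96 % 64 != 0
# The slab filter[q] is an RSCF view starting at q*R*S*C*F.
q = 2
print(qrscf_index(q, 1, 2, 5, 7, R, S, C, F) == q * R * S * C * F + rscf_index(1, 2, 5, 7, S, C, F))  # True
```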

Declined shapes fall through to dispatch_im2col_matmul_conv3d.

Functions