Mojo module

qslice_conv3d

Q-slice 3-D conv → Q sequential SM100 2-D conv calls with fp32 accumulator.

A 3-D conv with stride=1, dilation=1, and zero temporal padding decomposes exactly as:

output[n, d_o, h_o, w_o, f]
  = Σ_q  [ Σ_{r,s,c}  input[n, d_o+q, h_o+r-pad_h, w_o+s-pad_w, c]
                      * filter[q, r, s, c, f] ]

The inner Σ_{r,s,c} is a 2-D conv on a single depth slice, so by invoking the existing dispatch_sm100_conv2d Q times we inherit SM100 UMMA performance without forking any TMA code.
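The decomposition can be checked numerically. A minimal NumPy sketch (naive reference convs with illustrative shapes, not the Mojo dispatcher or the SM100 kernel):

```python
import numpy as np

rng = np.random.default_rng(0)
N, D, H, W, C = 1, 4, 5, 5, 3      # NDHWC input
Q, R, S, F = 2, 3, 3, 2            # QRSCF filter
pad = 1                            # spatial padding; temporal padding is zero

x = rng.standard_normal((N, D, H, W, C)).astype(np.float32)
w = rng.standard_normal((Q, R, S, C, F)).astype(np.float32)

def conv2d(x2d, w2d, pad):
    """Naive 2-D conv: NHWC input, RSCF filter, stride 1, dilation 1."""
    n, h, wi, c = x2d.shape
    r, s, _, f = w2d.shape
    xp = np.pad(x2d, ((0, 0), (pad, pad), (pad, pad), (0, 0)))
    ho, wo = h + 2 * pad - r + 1, wi + 2 * pad - s + 1
    out = np.zeros((n, ho, wo, f), np.float32)
    for i in range(ho):
        for j in range(wo):
            # contract the (r, s, c) axes of the patch against the filter
            out[:, i, j, :] = np.tensordot(
                xp[:, i:i + r, j:j + s, :], w2d, axes=3)
    return out

# Q-slice path: for each q, a 2-D conv on one depth slice, summed over q.
D_out, H_out, W_out = D - Q + 1, H + 2 * pad - R + 1, W + 2 * pad - S + 1
acc = np.zeros((N, D_out, H_out, W_out, F), np.float32)
for q in range(Q):
    for d_o in range(D_out):
        acc[:, d_o] += conv2d(x[:, d_o + q], w[q], pad)

# Reference: direct 3-D conv over the full (q, r, s, c) window.
xp = np.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad), (0, 0)))
ref = np.zeros_like(acc)
for d_o in range(D_out):
    for i in range(H_out):
        for j in range(W_out):
            ref[:, d_o, i, j, :] = np.tensordot(
                xp[:, d_o:d_o + Q, i:i + R, j:j + S, :], w, axes=4)
```

The two results agree up to float32 summation-order differences, which is the sense in which the decomposition is exact.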

ACCUMULATION STRATEGY

Each SM100 2-D conv produces a bf16 result (the kernel accumulates in fp32 internally but stores bf16). A direct attempt to chain the Q convs via has_residual=True / beta=1.0 round-trips the running sum through bf16 at every step, which compounds per-element rounding far beyond what a single 3-D conv produces (empirically, a max absolute difference of roughly 100-1000 against the reference on WAN mid_res shapes).

This dispatcher instead maintains a dedicated fp32 accumulator buffer outside the conv calls:

  1. Allocate accum_fp32 of output size and zero-fill it.
  2. Allocate one reusable temp_bf16 buffer of output size.
  3. For q in [0, Q):
    • Call dispatch_sm100_conv2d(has_residual=False) to write conv(input[q], filter[q]) into temp_bf16.
    • Launch an elementwise kernel that reads temp_bf16, casts to fp32, and adds into accum_fp32.
  4. Launch a final kernel that casts accum_fp32 → bf16 and writes to user output (fused with the caller's 5-D epilogue when one is provided).

Only the final cast is lossy; all intermediate accumulation is fp32, matching what a single all-at-once 3-D conv would produce internally.
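The difference between the two strategies can be simulated in NumPy by rounding float32 to bf16 precision by hand. A sketch (to_bf16 and the random stand-ins for the per-slice conv results are illustrative, not the dispatcher's code):

```python
import numpy as np

def to_bf16(a):
    """Round float32 to bf16 precision (round-to-nearest-even), returned as
    float32. Assumes finite inputs; bit-trick: add half-ulp plus the tie bit,
    then truncate the low 16 mantissa bits."""
    bits = np.asarray(a, np.float32).view(np.uint32)
    return (((bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000)
            .view(np.float32))

rng = np.random.default_rng(0)
Q, n = 8, 10_000
# stand-ins for the Q per-slice conv outputs (conv output dtype is bf16)
parts = [to_bf16(rng.standard_normal(n).astype(np.float32) * 100)
         for _ in range(Q)]
exact = np.sum([p.astype(np.float64) for p in parts], axis=0)  # fp64 reference

# (a) chaining via has_residual: running sum round-trips through bf16
chained = np.zeros(n, np.float32)
for p in parts:
    chained = to_bf16(chained + p)

# (b) dispatcher strategy: fp32 accumulator, one final bf16 cast
acc = np.zeros(n, np.float32)
for p in parts:
    acc += p
fused = to_bf16(acc)

err_chained = np.abs(chained - exact).max()
err_fused = np.abs(fused - exact).max()
```

With strategy (b) the only rounding beyond what the convs themselves introduce is the single final cast, so its error stays within one bf16 ulp of the fp32 result, while (a) accumulates roughly one rounding per step.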

Gate:

  • bf16 input/filter/output dtype.
  • SM100 device (_is_sm10x_gpu).
  • filter_is_fcrs=False (QRSCF only — the per-q slab is a contiguous RSCF view at offset q*R*S*C*F; FCQRS would need a separate extraction kernel because a fixed-q FCQRS slice is non-contiguous).
  • stride=1, dilation=1, groups=1, Q>1.
  • Zero temporal padding (WAN's causal 3-D convs pre-pad the temporal axis externally).
  • C_in % 64 == 0 and C_out % 64 == 0 (SM100 alignment).
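The filter-layout condition in the gate can be illustrated with NumPy views (illustrative shapes; the names w_qrscf / w_fcqrs are not from the module):

```python
import numpy as np

Q, R, S, C, F = 3, 3, 3, 64, 64
w_qrscf = np.arange(Q * R * S * C * F, dtype=np.float32).reshape(Q, R, S, C, F)

# QRSCF: the fixed-q slab is one contiguous run of R*S*C*F elements
# starting at flat offset q * R * S * C * F, viewable as RSCF with no copy.
q = 1
slab_len = R * S * C * F
flat = w_qrscf.reshape(-1)                       # row-major storage
slab = flat[q * slab_len:(q + 1) * slab_len].reshape(R, S, C, F)
assert np.shares_memory(slab, w_qrscf)           # a view, not a copy
assert np.array_equal(slab, w_qrscf[q])

# FCQRS by contrast: fixing q leaves a strided, non-contiguous slice,
# which is why that layout would need a separate extraction kernel.
w_fcqrs = np.ascontiguousarray(w_qrscf.transpose(4, 3, 0, 1, 2))
assert w_fcqrs.flags['C_CONTIGUOUS']
assert not w_fcqrs[:, :, q].flags['C_CONTIGUOUS']
```

The contiguous slab is what lets each per-q call reuse the existing 2-D conv path unchanged: it is handed an ordinary RSCF filter pointer at an offset.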

When C_out is 64-aligned but not 128-aligned (e.g., C_out=192), the dispatcher zero-pads the filter's F axis up to the next multiple of 128 (once, before the q loop), runs the accumulator at C_out_padded, and strides the final fp32→bf16 cast to drop the padded columns when writing user output. This costs ~33% extra compute on the 192→192 shape but keeps the MMA at its native 128-wide N tile.
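A NumPy stand-in for this padding path (shapes, MMA_N, and the im2col-style GEMM are illustrative substitutes for the real SM100 pipeline):

```python
import numpy as np

rng = np.random.default_rng(0)
R, S, C, C_out = 3, 3, 64, 192
MMA_N = 128
C_out_padded = -(-C_out // MMA_N) * MMA_N        # ceil to multiple of 128

# 1. Zero-pad the filter's F axis once, before the q loop.
w = rng.standard_normal((R, S, C, C_out)).astype(np.float32)
w_padded = np.zeros((R, S, C, C_out_padded), np.float32)
w_padded[..., :C_out] = w

# 2. The accumulator lives at the padded width.
n_pix = 10
accum = np.zeros((n_pix, C_out_padded), np.float32)
x_patches = rng.standard_normal((n_pix, R * S * C)).astype(np.float32)
accum += x_patches @ w_padded.reshape(-1, C_out_padded)  # im2col-style GEMM

# 3. Padded columns only ever receive exact zeros (every product is x * 0),
#    so the final cast can simply stride past them when writing the
#    user's C_out-wide output.
out = accum[:, :C_out]
```

For C_out=192 this pads to 256, i.e. 256/192 - 1 ≈ 33% extra compute, matching the figure quoted above.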

Declined shapes fall through to dispatch_im2col_matmul_conv3d.

Functions
