Mojo module
qslice_conv3d
Q-slice 3-D conv → Q sequential SM100 2-D conv calls with fp32 accumulator.
A 3-D conv with stride=1, dilation=1 decomposes exactly as:

    output[n, d_o, h_o, w_o, f]
      = Σ_q [ Σ_{r,s,c} input[n, d_o+q, h_o+r-pad_h, w_o+s-pad_w, c]
                        * filter[q, r, s, c, f] ]

The inner Σ_{r,s,c} is a 2-D conv on a single depth slice, so by
invoking the existing `dispatch_sm100_conv2d` Q times we inherit SM100
UMMA performance without forking any TMA code.
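The decomposition above can be checked numerically. The sketch below is a pure-Python stand-in (not the Mojo kernels), using tiny NDHWC input and QRSCF filter tensors with zero padding for brevity; it confirms that a direct 3-D conv equals the sum over q of per-slice 2-D convs.

```python
# Verify: 3-D conv (stride=1, dilation=1, zero padding) == Σ_q of 2-D convs,
# each 2-D conv reading a single depth slice of the input.
import itertools
import random

random.seed(0)
N, D, H, W, C = 1, 4, 5, 5, 2      # input  [n, d, h, w, c]
Q, R, S, F = 2, 3, 3, 2            # filter [q, r, s, c, f]
x = [[[[[random.random() for _ in range(C)] for _ in range(W)]
       for _ in range(H)] for _ in range(D)] for _ in range(N)]
w = [[[[[random.random() for _ in range(F)] for _ in range(C)]
       for _ in range(S)] for _ in range(R)] for _ in range(Q)]
Do, Ho, Wo = D - Q + 1, H - R + 1, W - S + 1

def direct_3d(n, do, ho, wo, f):
    # Single fused sum over q, r, s, c.
    return sum(x[n][do + q][ho + r][wo + s][c] * w[q][r][s][c][f]
               for q in range(Q) for r in range(R)
               for s in range(S) for c in range(C))

def qslice(n, do, ho, wo, f):
    # Σ_q of inner 2-D convs: the q-th term only touches depth slice do+q.
    return sum(
        sum(x[n][do + q][ho + r][wo + s][c] * w[q][r][s][c][f]
            for r in range(R) for s in range(S) for c in range(C))
        for q in range(Q))

for n, do, ho, wo, f in itertools.product(
        range(N), range(Do), range(Ho), range(Wo), range(F)):
    assert abs(direct_3d(n, do, ho, wo, f) - qslice(n, do, ho, wo, f)) < 1e-9
```

The two sums differ only in association order, so they agree to floating-point rounding at every output element.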
ACCUMULATION STRATEGY
Each SM100 2-D conv produces a bf16 result (the kernel uses an fp32
accumulator internally but stores bf16). A direct attempt to chain
Q convs via `has_residual=True` / `beta=1.0` round-trips the running
sum through bf16 every step, which compounds per-element rounding
far beyond what a single 3-D conv produces (empirically ~100-1000
diff vs the reference on WAN mid_res shapes).
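The failure mode is easy to reproduce. The sketch below is a pure-Python stand-in (not the SM100 kernels): bf16 keeps 8 exponent bits and 7 explicit mantissa bits, so it can be emulated by truncating the low 16 bits of an fp32 encoding. Near a running sum of 256 the bf16 grid step is 2, so a bf16-stored running sum cannot absorb +1.0 increments at all, while an fp32 accumulator loses precision only at the single final cast.

```python
# Emulate bf16 by truncating the low 16 bits of the fp32 bit pattern.
import struct

def to_bf16(v):
    (bits,) = struct.unpack("<I", struct.pack("<f", v))
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

# One large partial followed by small ones.
partials = [256.0] + [1.0] * 15
exact = sum(partials)                      # 271.0

# Chained residual adds (has_residual=True / beta=1.0 style): the running
# sum is stored in bf16 after every step, so every +1.0 is rounded away.
chained = 0.0
for p in partials:
    chained = to_bf16(chained + to_bf16(p))

# Dedicated fp32 accumulator: each partial is bf16 (the conv's stored
# result), but the running sum stays full precision; only the final
# store is cast down.
accum = sum(to_bf16(p) for p in partials)
final = to_bf16(accum)

assert chained == 256.0    # off by 15 from the exact 271.0
assert final == 270.0      # only the single final cast is lossy
```

The real kernels round rather than truncate, but the compounding mechanism is the same: Q lossy stores of the running sum versus one.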
This dispatcher instead maintains a dedicated fp32 accumulator buffer outside the conv calls:

- Allocate `accum_fp32` of output size and zero-fill it.
- Allocate one reusable `temp_bf16` buffer of output size.
- For q in [0, Q):
  - Call `dispatch_sm100_conv2d(has_residual=False)` to write `conv(input[q], filter[q])` into `temp_bf16`.
  - Launch an elementwise kernel that reads `temp_bf16`, casts to fp32, and adds into `accum_fp32`.
- Launch a final kernel that casts `accum_fp32` → bf16 and writes to user `output` (fused with the caller's 5-D epilogue when one is provided).
Only the final cast is lossy; all intermediate accumulation is fp32, matching what a single all-at-once 3-D conv would produce internally.
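The buffer choreography can be sketched in a few lines. This is a hypothetical Python stand-in, not the Mojo dispatcher: `conv2d` plays the role of `dispatch_sm100_conv2d` with `has_residual=False`, and bf16 storage is modeled by a `cast_bf16` callback.

```python
def qslice_conv3d(q_inputs, q_filters, conv2d, cast_bf16):
    """Accumulate Q per-slice 2-D convs in fp32; cast to bf16 once at the end."""
    Q = len(q_filters)
    accum_fp32 = None
    for q in range(Q):
        # Step 1: one 2-D conv per depth slice; its stored result is bf16.
        temp_bf16 = [cast_bf16(v) for v in conv2d(q_inputs[q], q_filters[q])]
        if accum_fp32 is None:
            accum_fp32 = [0.0] * len(temp_bf16)   # zero-filled fp32 buffer
        # Step 2: elementwise bf16 -> fp32 add into the accumulator.
        for i, v in enumerate(temp_bf16):
            accum_fp32[i] += v
    # Step 3: the single lossy fp32 -> bf16 cast into the user's output.
    return [cast_bf16(v) for v in accum_fp32]

# With an identity cast the result is the exact running sum:
out = qslice_conv3d([[1.0], [2.0]], [None, None],
                    lambda x, f: x, lambda v: v)
assert out == [3.0]
```

In the real dispatcher steps 1–2 reuse the same `temp_bf16` buffer across q, and step 3 is where the caller's 5-D epilogue is fused.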
Gate:

- bf16 input/filter/output dtype.
- SM100 device (`_is_sm10x_gpu`).
- `filter_is_fcrs=False` (QRSCF only: the per-q slab is a contiguous RSCF view at offset `q*R*S*C*F`; FCQRS would need a separate extraction kernel because a fixed-q FCQRS slice is non-contiguous).
- stride=1, dilation=1, groups=1, Q>1.
- Zero temporal padding (WAN's causal 3-D convs pre-pad temporal externally).
- `C_in % 64 == 0` and `C_out % 64 == 0` (SM100 alignment).
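The gate conditions above can be rendered as a single predicate. This is an illustrative Python sketch; the parameter names are assumptions, not the Mojo signature.

```python
def qslice_gate(dtype, is_sm100, filter_is_fcrs, stride, dilation,
                groups, Q, pad_d, c_in, c_out):
    """True iff a 3-D conv qualifies for the Q-slice SM100 path."""
    return (dtype == "bfloat16"                     # bf16 in/filter/out
            and is_sm100                            # _is_sm10x_gpu
            and not filter_is_fcrs                  # QRSCF layout only
            and stride == (1, 1, 1)
            and dilation == (1, 1, 1)
            and groups == 1
            and Q > 1
            and pad_d == 0                          # zero temporal padding
            and c_in % 64 == 0 and c_out % 64 == 0) # SM100 alignment

# A WAN-style 3x3x3 bf16 conv with 64-aligned channels qualifies:
assert qslice_gate("bfloat16", True, False, (1, 1, 1), (1, 1, 1),
                   1, 3, 0, 64, 192)
# A non-64-aligned C_out falls through to the im2col path:
assert not qslice_gate("bfloat16", True, False, (1, 1, 1), (1, 1, 1),
                       1, 3, 0, 64, 100)
```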
When C_out is 64-aligned but not 128-aligned (e.g., C_out=192),
the dispatcher zero-pads the filter's F axis up to the next
multiple of 128 (once, before the q loop), runs the accumulator
at C_out_padded, and strides the final fp32→bf16 cast to drop
the padded columns when writing user output. This costs ~33%
extra compute on the 192→192 shape but keeps the MMA at its
native 128-wide N tile.
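The padding arithmetic behind the ~33% figure is just rounding F up to the next multiple of 128. A small sketch (hypothetical helper name; the real dispatcher pads the filter buffer once, before the q loop):

```python
def pad_to_128(c_out):
    # Next multiple of 128 at or above c_out.
    return ((c_out + 127) // 128) * 128

c_out = 192
c_out_padded = pad_to_128(c_out)         # 256: two native 128-wide N tiles
overhead = c_out_padded / c_out - 1.0    # extra MMA work along the F axis

assert c_out_padded == 256
assert abs(overhead - 1 / 3) < 1e-9      # the ~33% on the 192->192 shape
assert pad_to_128(128) == 128            # already aligned: no overhead
```

The final fp32 → bf16 cast simply strides past columns `c_out..c_out_padded` when writing the user's output, so the padding never escapes the accumulator.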
Declined shapes fall through to `dispatch_im2col_matmul_conv3d`.
Functions
- `dispatch_qslice_conv3d_sm100`: Try to dispatch a 3-D conv as Q × SM100 2-D conv calls with a dedicated fp32 accumulator.