Mojo module
matmul_1x1x1_conv3d
Direct _matmul_gpu dispatch for 1x1x1 3D convolutions.
A 1x1x1 conv (Q=R=S=1) with stride=1, dilation=1, groups=1, and zero padding is algebraically identical to a single matmul:
output[b, d, h, w, f] = Σ_c input[b, d, h, w, c] * filter[0, 0, 0, c, f]
NDHWC input is already contiguous with C innermost, so it can be viewed as [M, C_in] with M = B*D*H*W on the same pointer. An FCQRS or QRSCF filter with Q=R=S=1 reduces to [F, C] or [C, F] respectively, so no transpose kernel is needed. The NDHWC output collapses to [M, F]. No scratch allocation is required, and the epilogue unflattens (m, f) -> (b, d, h, w, f) in one call to the 5D lambda.
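The sketch below is a plain-Python sanity check of the identity above. It is not the Mojo implementation, and every function name in it is hypothetical. Buffers are flat lists in row-major NDHWC order, so reading row m of the [M, C] view uses the same index arithmetic as reading input[b, d, h, w, :]. That is why the kernel can reuse the input pointer. The sketch checks both filter layouts. QRSCF is read as [C, F]. FCQRS is read as [F, C] with a transposed-B access pattern.

```python
import itertools
import random


def ndhwc_offset(b, d, h, w, c, D, H, W, C):
    # Row-major NDHWC: the channel dimension is innermost (stride 1).
    return (((b * D + d) * H + h) * W + w) * C + c


def conv3d_1x1x1(inp, filt, B, D, H, W, C, F):
    # Direct definition: out[b,d,h,w,f] = sum_c inp[b,d,h,w,c] * filt[0,0,0,c,f],
    # with filt stored flat as QRSCF (Q=R=S=1), i.e. as [C, F].
    out = [0.0] * (B * D * H * W * F)
    for b, d, h, w, f in itertools.product(range(B), range(D), range(H), range(W), range(F)):
        acc = 0.0
        for c in range(C):
            acc += inp[ndhwc_offset(b, d, h, w, c, D, H, W, C)] * filt[c * F + f]
        out[ndhwc_offset(b, d, h, w, f, D, H, W, F)] = acc
    return out


def matmul_view(inp, filt, M, C, F, transpose_b=False):
    # The same flat buffers reinterpreted as A:[M, C] and B:[C, F],
    # or B:[F, C] read transposed when transpose_b is set (FCQRS case).
    out = [0.0] * (M * F)
    for m in range(M):
        for f in range(F):
            acc = 0.0
            for c in range(C):
                b_val = filt[f * C + c] if transpose_b else filt[c * F + f]
                acc += inp[m * C + c] * b_val
            out[m * F + f] = acc
    return out


def unflatten(m, D, H, W):
    # Epilogue index math: flattened row m -> (b, d, h, w) for a 5D output.
    m, w = divmod(m, W)
    m, h = divmod(m, H)
    b, d = divmod(m, D)
    return b, d, h, w


# Small random problem: compare both filter layouts against the direct conv.
random.seed(0)
B, D, H, W, C, F = 2, 3, 4, 5, 6, 7
M = B * D * H * W
inp = [random.uniform(-1, 1) for _ in range(M * C)]
filt_cf = [random.uniform(-1, 1) for _ in range(C * F)]              # QRSCF, Q=R=S=1
filt_fc = [filt_cf[c * F + f] for f in range(F) for c in range(C)]   # FCQRS, Q=R=S=1

ref = conv3d_1x1x1(inp, filt_cf, B, D, H, W, C, F)
via_cf = matmul_view(inp, filt_cf, M, C, F)
via_fc = matmul_view(inp, filt_fc, M, C, F, transpose_b=True)
print(via_cf == ref, via_fc == ref)
```

Both paths multiply the same values and accumulate them in the same order. The results therefore compare exactly equal, not just approximately.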
Covers every 1x1x1 case in the WAN VAE (post_quant_conv, per-block
conv_shortcut). Used in conv_gpu's 5D arm as the first branch in
the QRSCF dispatch chain, before dispatch_im2col_matmul_conv3d.
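The eligibility conditions and the dispatch order can be modeled as a first-match chain. The sketch below is a hypothetical Python model. Conv3dParams, try_matmul_1x1x1, try_im2col_matmul, and dispatch_qrscf are illustrative names, not the module's API. The docs only say "Try to dispatch", so the sketch assumes that each branch reports whether it handled the conv. The im2col branch is a placeholder, because its real eligibility rules are not described here.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class Conv3dParams:
    # Hypothetical parameter bundle; field names are illustrative only.
    filter_qrs: Tuple[int, int, int]   # (Q, R, S) spatial filter extents
    stride: Tuple[int, int, int]
    dilation: Tuple[int, int, int]
    padding: Tuple[int, ...]           # per-side padding amounts
    groups: int


def try_matmul_1x1x1(p: Conv3dParams) -> Optional[str]:
    # Preconditions from the module docs: Q=R=S=1, stride=1, dilation=1,
    # groups=1, and zero padding. Anything else is declined.
    eligible = (
        p.filter_qrs == (1, 1, 1)
        and all(s == 1 for s in p.stride)
        and all(d == 1 for d in p.dilation)
        and all(x == 0 for x in p.padding)
        and p.groups == 1
    )
    # A string stands in for launching the single matmul.
    return "matmul_1x1x1" if eligible else None


def try_im2col_matmul(p: Conv3dParams) -> Optional[str]:
    # Placeholder for the next branch; its real eligibility rules are not modeled.
    return "im2col_matmul"


def dispatch_qrscf(p: Conv3dParams) -> Optional[str]:
    # First match wins, so the 1x1x1 matmul path is tried before im2col.
    for attempt in (try_matmul_1x1x1, try_im2col_matmul):
        handled = attempt(p)
        if handled is not None:
            return handled
    return None  # later branches of the real chain are not modeled
```

Because the 1x1x1 branch has the narrowest preconditions, it costs almost nothing when it declines. Trying it first does not slow down convs that end up on the im2col path.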
Functions
- dispatch_1x1x1_matmul_conv3d: Try to dispatch a 1x1x1 3D conv directly as a single _matmul_gpu call.