Mojo module
conv2d_fprop_kernel
SM100 Conv2D Forward Propagation Kernel.
This module implements a high-performance Conv2D fprop kernel for NVIDIA Blackwell (SM100) GPUs using the Structured Kernel architecture.
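The kernel computes the convolution as an implicit GEMM:
- The convolution maps to the GEMM C[M,N] = A[M,K] @ B[K,N], where A is the im2col view of the input activations and B holds the filter weights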
- M = batch * out_h * out_w (output spatial)
- N = out_channels (filters)
- K = in_channels * filter_h * filter_w (reduction)
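The dimension mapping above can be checked with a few lines of arithmetic. The following is a Python sketch for illustration only, not part of the Mojo module: the function name `implicit_gemm_shape` and the symmetric `pad_h`/`pad_w` parameters are assumptions, and it covers only the stride=1, dilation=1 case.

```python
def implicit_gemm_shape(batch, in_h, in_w, in_channels, out_channels,
                        filter_h, filter_w, pad_h=0, pad_w=0):
    """Return (M, N, K) of the GEMM equivalent to a conv2d forward pass.

    Assumes stride=1 and dilation=1, with symmetric zero padding.
    """
    # Output spatial size for stride=1, dilation=1.
    out_h = in_h + 2 * pad_h - filter_h + 1
    out_w = in_w + 2 * pad_w - filter_w + 1

    m = batch * out_h * out_w                # output spatial positions
    n = out_channels                         # filters
    k = in_channels * filter_h * filter_w    # reduction dimension
    return m, n, k
```

For example, a 3x3 convolution over a 1x64x64 input with 128 input channels, 128 filters, and padding 1 gives M = 4096, N = 128, K = 1152.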
The implementation reuses matmul infrastructure:
- 8-warp specialization across five roles (scheduler, load, MMA, epilogue load, epilogue)
- TMA-based tile loading with im2col addressing
- TMEM accumulators
- Producer-consumer pipelining
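To make "im2col addressing" concrete: the A matrix is never materialized, and each GEMM coordinate (m, k) is instead translated into an input coordinate at load time. The Python sketch below illustrates that index arithmetic only. It assumes an NHWC input, a reduction index ordered as (filter row, filter column, channel), stride=1, dilation=1, and zero-filled padding; the names `im2col_coord` and `load_a` are hypothetical, and the kernel's actual layout and TMA descriptors may differ.

```python
def im2col_coord(m, k, out_h, out_w, in_channels, filter_w, pad_h, pad_w):
    """Map GEMM coordinate (m, k) to an input coordinate (n, h, w, c).

    Assumed orderings: m = (n, p, q) over batch and output position,
    k = (r, s, c) over filter row, filter column, and input channel.
    """
    n, rem = divmod(m, out_h * out_w)
    p, q = divmod(rem, out_w)            # output row and column
    rs, c = divmod(k, in_channels)
    r, s = divmod(rs, filter_w)          # filter row and column
    # stride=1, dilation=1: the input position is output position plus tap offset.
    return n, p + r - pad_h, q + s - pad_w, c


def load_a(x, m, k, out_h, out_w, filter_w, pad_h, pad_w):
    """Return the implicit A[m, k] from x[N][H][W][C]; padding reads as zero."""
    in_h, in_w, in_channels = len(x[0]), len(x[0][0]), len(x[0][0][0])
    n, h, w, c = im2col_coord(m, k, out_h, out_w, in_channels,
                              filter_w, pad_h, pad_w)
    if 0 <= h < in_h and 0 <= w < in_w:
        return x[n][h][w][c]
    return 0.0
```

With a 3x3 filter and padding 1, the center tap of any output position maps back to the same spatial position in the input, while corner taps of border outputs fall outside the input and read as zero.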
Supported configurations (optimized for the Flux VAE):
- stride=1, dilation=1 (the most common case in the VAE decoder)
- 3x3 and 1x1 filters
- BF16/FP16 data types
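A caller might guard dispatch to this kernel with a check along the lines of the list above. The sketch below is a hypothetical Python helper, not an API of this module: the name `is_supported`, its parameters, and the dtype strings are all assumptions made for illustration.

```python
# Hypothetical constants mirroring the supported configurations listed above.
SUPPORTED_FILTERS = {(1, 1), (3, 3)}
SUPPORTED_DTYPES = {"bfloat16", "float16"}


def is_supported(stride, dilation, filter_hw, dtype):
    """Return True if a conv2d configuration matches the documented list."""
    return (
        tuple(stride) == (1, 1)
        and tuple(dilation) == (1, 1)
        and tuple(filter_hw) in SUPPORTED_FILTERS
        and dtype in SUPPORTED_DTYPES
    )
```

Configurations that fail such a check would need a different convolution path.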
Structs
- Conv2dFpropKernel: SM100 Conv2D forward propagation kernel.