Mojo struct
Conv2dFpropKernel
struct Conv2dFpropKernel[act_type: DType, filter_type: DType, out_type: DType, act_layout: Layout, filter_layout: Layout, out_layout: Layout, act_desc_layout: Layout, filter_desc_layout: Layout, out_desc_layout: Layout, config: Conv2dConfig[act_type, filter_type, out_type], cluster_shape: StaticTuple[Int32, 3] = StaticTuple[Int32, 3](1), elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, register_based_epilogue: Bool = True, src_layout: Layout = out_layout, src_desc_layout: Layout = out_desc_layout]
SM100 Conv2D forward propagation kernel.
This kernel implements conv2d fprop using implicit GEMM with warp specialization. It reuses the matmul kernel architecture but with convolution-specific address calculation.
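A minimal Python sketch of the implicit-GEMM idea, for illustration only. It checks that a direct convolution equals a GEMM whose A operand is gathered on the fly (im2col) rather than materialized. The NHWC activation and KRSC filter layouts, the (r, s, c) ordering of the K dimension, and all helper names are assumptions of this sketch, not taken from the kernel.

```python
def out_size(extent, kernel, stride, pad):
    # Output extent along one spatial dimension (no dilation in this sketch).
    return (extent + 2 * pad - kernel) // stride + 1


def conv2d_direct(x, flt, stride=1, pad=0):
    """Reference fprop. x: [N][H][W][C], flt: [K][R][S][C] -> y: [N][P][Q][K]."""
    N, H, W, C = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    K, R, S = len(flt), len(flt[0]), len(flt[0][0])
    P, Q = out_size(H, R, stride, pad), out_size(W, S, stride, pad)
    y = [[[[0.0] * K for _ in range(Q)] for _ in range(P)] for _ in range(N)]
    for n in range(N):
        for p in range(P):
            for q in range(Q):
                for k in range(K):
                    acc = 0.0
                    for r in range(R):
                        for s in range(S):
                            h, w = p * stride - pad + r, q * stride - pad + s
                            if 0 <= h < H and 0 <= w < W:  # skip zero padding
                                for c in range(C):
                                    acc += x[n][h][w][c] * flt[k][r][s][c]
                    y[n][p][q][k] = acc
    return y


def conv2d_implicit_gemm(x, flt, stride=1, pad=0):
    """Same output via D[m][n] = sum_k A[m][k] * B[n][k], with A never materialized."""
    N, H, W, C = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    K, R, S = len(flt), len(flt[0]), len(flt[0][0])
    P, Q = out_size(H, R, stride, pad), out_size(W, S, stride, pad)
    gemm_m, gemm_n, gemm_k = N * P * Q, K, C * R * S

    def a_elem(m, kk):
        # im2col on the fly: m -> (n, p, q), kk -> (r, s, c) with c fastest.
        n, pq = divmod(m, P * Q)
        p, q = divmod(pq, Q)
        rs, c = divmod(kk, C)
        r, s = divmod(rs, S)
        h, w = p * stride - pad + r, q * stride - pad + s
        return x[n][h][w][c] if 0 <= h < H and 0 <= w < W else 0.0

    def b_elem(nn, kk):
        # Filter viewed as an N x K matrix (output channel x reduction index).
        rs, c = divmod(kk, C)
        r, s = divmod(rs, S)
        return flt[nn][r][s][c]

    d = [[sum(a_elem(m, kk) * b_elem(nn, kk) for kk in range(gemm_k))
          for nn in range(gemm_n)] for m in range(gemm_m)]
    # Row m of D is output pixel (n, p, q); reshape back to [N][P][Q][K].
    return [[[d[(n * P + p) * Q + q] for q in range(Q)] for p in range(P)]
            for n in range(N)]


def make_inputs(N=2, H=5, W=4, C=2, K=3, R=3, S=2):
    # Small integer-valued data so both summation orders give identical floats.
    x = [[[[float((n + 3 * h + 5 * w + 7 * c) % 4) for c in range(C)]
           for w in range(W)] for h in range(H)] for n in range(N)]
    flt = [[[[float((k + 2 * r + s + 3 * c) % 3 - 1) for c in range(C)]
             for s in range(S)] for r in range(R)] for k in range(K)]
    return x, flt


x, flt = make_inputs()
print(conv2d_direct(x, flt, stride=1, pad=1) == conv2d_implicit_gemm(x, flt, stride=1, pad=1))  # True
```

The kernel's im2col TMA path plays the role of `a_elem` here: the GEMM operand is addressed in (m, k) space while the hardware fetches the corresponding activation elements.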
The kernel structure:
- Scheduler warp: tile scheduling based on CLC (Cluster Launch Control)
- Load warp: TMA loads with im2col transformation
- MMA warp: tensor core operations
- Epilogue warps: write results from TMEM to GMEM
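As a rough check of the warp-specialized layout, the Python sketch below totals the per-role thread counts defined under comptime members (SCHEDULER_THREADS, TMA_LOAD_THREADS, MMA_THREADS, EPILOGUE_LOAD_THREADS, EPILOGUE_THREADS), which is what NUM_THREADS sums. The contiguous warp-to-role ordering in the hypothetical `role_of_warp` helper is an assumption; the kernel's actual warp assignment is not documented here.

```python
WARP_SIZE = 32  # threads per warp on NVIDIA GPUs

# Per-role thread counts, transcribed from the comptime members.
ROLE_THREADS = {
    "scheduler": WARP_SIZE,      # SCHEDULER_THREADS
    "tma_load": WARP_SIZE,       # TMA_LOAD_THREADS
    "mma": WARP_SIZE,            # MMA_THREADS
    "epilogue_load": WARP_SIZE,  # EPILOGUE_LOAD_THREADS
    "epilogue": 4 * WARP_SIZE,   # EPILOGUE_THREADS (num_output_warps = 4)
}
NUM_THREADS = sum(ROLE_THREADS.values())
NUM_WARPS = NUM_THREADS // WARP_SIZE


def role_of_warp(warp_id, order=("scheduler", "tma_load", "mma", "epilogue_load", "epilogue")):
    # Hypothetical layout: roles occupy contiguous warp ranges in `order`.
    first = 0
    for role in order:
        count = ROLE_THREADS[role] // WARP_SIZE
        if warp_id < first + count:
            return role
        first += count
    raise ValueError(f"warp {warp_id} is outside the {NUM_WARPS}-warp block")


print(NUM_THREADS, NUM_WARPS)  # 256 8
print([role_of_warp(w) for w in range(NUM_WARPS)])
```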
Parameters
- act_type (DType): Activation data type.
- filter_type (DType): Filter data type.
- out_type (DType): Output data type.
- act_layout (Layout): Global memory activation layout.
- filter_layout (Layout): Global memory filter layout.
- out_layout (Layout): Global memory output layout.
- act_desc_layout (Layout): TMA descriptor layout for activation.
- filter_desc_layout (Layout): TMA descriptor layout for filter.
- out_desc_layout (Layout): TMA descriptor layout for output.
- config (Conv2dConfig): Kernel configuration.
- cluster_shape (StaticTuple): CUDA cluster dimensions.
- elementwise_compute_lambda_fn (Optional): Optional epilogue lambda for fusion (bias add, activation functions, residual connections).
- register_based_epilogue (Bool): Whether to apply the lambda in registers.
- src_layout (Layout): Global memory layout for source C (residual input).
- src_desc_layout (Layout): TMA descriptor layout for source C.
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
__del__is_trivial
comptime __del__is_trivial = True
accum_layout
comptime accum_layout = Layout.row_major(Self.MMA_M, Self.MMA_N)
accum_pipeline_consumer_arv_count
comptime accum_pipeline_consumer_arv_count = Self.cta_group * Self.EPILOGUE_THREADS
accum_pipeline_producer_arv_count
comptime accum_pipeline_producer_arv_count = 1
accum_type
comptime accum_type = Conv2dConfig.accum_type[act_type, filter_type, out_type]()
AccumTensor
comptime AccumTensor = TmemTensor[Self.accum_type, Self.accum_layout, cta_group=Self.cta_group]
act_expected_bytes
comptime act_expected_bytes = Self.SmemType.act_smem_layout.size() * size_of[act_type]()
act_tma_load_size
comptime act_tma_load_size = act_desc_layout.size()
act_tma_rows
comptime act_tma_rows = act_desc_layout.shape[0].value()
ActTileLoaderTypeIm2col
comptime ActTileLoaderTypeIm2col = TileLoaderTMAIm2col[?, ?, ?, ?, cta_group=Self.cta_group]
BK
comptime BK = config.block_tile_shape[2]
BM
comptime BM = config.block_tile_shape[0]
BN
comptime BN = config.block_tile_shape[1]
c_smem_layout
comptime c_smem_layout = Layout.row_major(Self.SmemType.OutputM, Self.SmemType.OutputN)
clc_consumer_arv_count
comptime clc_consumer_arv_count = Self.SCHEDULER_THREADS + Self.CLUSTER_SIZE * (Self.TMA_LOAD_THREADS + Self.MMA_THREADS + Self.EPILOGUE_LOAD_THREADS + Self.EPILOGUE_THREADS)
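The arrival counts can be evaluated for concrete cluster shapes. This Python sketch transcribes the clc_consumer_arv_count and accum_pipeline_consumer_arv_count formulas, plus the thread count passed to MmaEpilogueSync, with WARP_SIZE = 32. The cluster shapes and cta_group value used below are hypothetical inputs.

```python
WARP_SIZE = 32
SCHEDULER_THREADS = WARP_SIZE
TMA_LOAD_THREADS = WARP_SIZE
MMA_THREADS = WARP_SIZE
EPILOGUE_LOAD_THREADS = WARP_SIZE
EPILOGUE_THREADS = 4 * WARP_SIZE


def clc_consumer_arv_count(cluster_m, cluster_n):
    # One SCHEDULER_THREADS term, plus all other roles in every CTA of the cluster.
    cluster_size = cluster_m * cluster_n
    return SCHEDULER_THREADS + cluster_size * (
        TMA_LOAD_THREADS + MMA_THREADS + EPILOGUE_LOAD_THREADS + EPILOGUE_THREADS
    )


def accum_pipeline_consumer_arv_count(cta_group):
    # Epilogue threads of each CTA in the MMA group.
    return cta_group * EPILOGUE_THREADS


MMA_EPILOGUE_SYNC_THREADS = MMA_THREADS + EPILOGUE_THREADS  # WarpGroupBarrier size

print(clc_consumer_arv_count(1, 1))          # 256
print(clc_consumer_arv_count(2, 1))          # 480
print(accum_pipeline_consumer_arv_count(2))  # 256
print(MMA_EPILOGUE_SYNC_THREADS)             # 160
```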
clc_producer_arv_count
comptime clc_producer_arv_count = 1
clc_throttle_consumer_arv_count
comptime clc_throttle_consumer_arv_count = Self.SCHEDULER_THREADS
clc_throttle_producer_arv_count
comptime clc_throttle_producer_arv_count = Self.TMA_LOAD_THREADS
CLUSTER_M
comptime CLUSTER_M = config.cluster_shape[0]
CLUSTER_N
comptime CLUSTER_N = config.cluster_shape[1]
CLUSTER_SIZE
comptime CLUSTER_SIZE = Self.CLUSTER_M * Self.CLUSTER_N
Context
comptime Context = KernelContext[Self.num_clc_pipeline_stages, Self.cta_group, Self.CLUSTER_M, Self.CLUSTER_N]
cta_group
comptime cta_group = config.cta_group
epi_load_consumer_arv_count
comptime epi_load_consumer_arv_count = SIMD[DType.int32, 1](Self.EPILOGUE_THREADS)
epi_load_producer_arv_count
comptime epi_load_producer_arv_count = 1
EpiLoadPipelineType
comptime EpiLoadPipelineType = EpiLoadPipeline[Self.num_epi_load_stages]
EPILOGUE_LOAD_THREADS
comptime EPILOGUE_LOAD_THREADS = WARP_SIZE
EPILOGUE_THREADS
comptime EPILOGUE_THREADS = (4 * WARP_SIZE)
EpilogueCtx
comptime EpilogueCtx = EpilogueWarpContext[config.num_accum_pipeline_stages, Self.stage_stride_cols, Self.cta_group, Self.MMA_THREADS, Self.EPILOGUE_THREADS]
filter_expected_bytes
comptime filter_expected_bytes = Self.SmemType.filter_smem_layout.size() * size_of[filter_type]()
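For a feel of the TMA transaction sizes, this Python sketch evaluates act_expected_bytes and filter_expected_bytes assuming the activation and filter shared-memory tiles are BM x BK and BN x BK elements (the shapes passed to TilePayload). The tile sizes, dtype, and cta_group below are hypothetical. input_expected_bytes multiplies by a further config-dependent factor that the rendered definition shows only as `config`, so that factor is not modeled.

```python
SIZE_OF = {"bfloat16": 2, "float16": 2, "float32": 4}  # bytes per element


def tile_bytes(rows, bk, dtype):
    # Bytes one TMA load of a rows x BK tile is expected to transfer.
    return rows * bk * SIZE_OF[dtype]


BM, BN, BK = 128, 128, 64  # hypothetical block_tile_shape
cta_group = 2              # hypothetical

act_expected_bytes = tile_bytes(BM, BK, "bfloat16")
filter_expected_bytes = tile_bytes(BN, BK, "bfloat16")
per_stage = cta_group * (act_expected_bytes + filter_expected_bytes)

print(act_expected_bytes, filter_expected_bytes, per_stage)  # 16384 16384 65536
```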
filter_tma_load_size
comptime filter_tma_load_size = filter_desc_layout.size()
filter_tma_rows
comptime filter_tma_rows = filter_desc_layout.shape[0].value()
FilterTileLoaderType
comptime FilterTileLoaderType = TileLoaderTMA[?, ?, ?, ?, cta_group=Self.cta_group]
input_expected_bytes
comptime input_expected_bytes = Self.cta_group * (Self.act_expected_bytes + Self.filter_expected_bytes) * config
InputTilePipelineType
comptime InputTilePipelineType = InputTilePipeline[Self.TilePayload, Self.SmemType.num_group_pipeline_stages, config.k_group_size]
MMA_K
comptime MMA_K = config.mma_shape[2]
MMA_M
comptime MMA_M = config.mma_shape[0]
MMA_N
comptime MMA_N = config.mma_shape[1]
MMA_THREADS
comptime MMA_THREADS = WARP_SIZE
MmaCtx
comptime MmaCtx = MmaWarpContext[config.num_accum_pipeline_stages, Self.stage_stride_cols, Self.cta_group, Self.MMA_THREADS, Self.EPILOGUE_THREADS]
MmaEpilogueSync
comptime MmaEpilogueSync = WarpGroupBarrier[Self.MMA_THREADS + Self.EPILOGUE_THREADS, 1]
MmaOp
comptime MmaOp = MmaOpSM100_SS[out_type, act_type, filter_type, config.block_tile_shape, config.mma_shape, accum_type=Self.accum_type, cta_group=Self.cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=True]
num_accum_pipeline_stages
comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages
num_clc_pipeline_stages
comptime num_clc_pipeline_stages = config.num_clc_pipeline_stages
num_epi_load_stages
comptime num_epi_load_stages = Self.SmemType.num_epi_load_stages
num_group_pipeline_stages
comptime num_group_pipeline_stages = Self.num_pipeline_stages // config
num_output_stages
comptime num_output_stages = config.num_output_stages
num_output_warps
comptime num_output_warps = 4
num_pipeline_stages
comptime num_pipeline_stages = config.num_pipeline_stages
NUM_THREADS
comptime NUM_THREADS = Self.SCHEDULER_THREADS + Self.TMA_LOAD_THREADS + Self.MMA_THREADS + Self.EPILOGUE_LOAD_THREADS + Self.EPILOGUE_THREADS
NUM_TMEM_COLS
comptime NUM_TMEM_COLS = 512
OutputM
comptime OutputM = config.output_tile_shape[0]
OutputN
comptime OutputN = config.output_tile_shape[1]
OutputPipeline
comptime OutputPipeline = OutputTilePipeline[config.num_accum_pipeline_stages, Self.stage_stride_cols, Self.cta_group]
Scheduler
comptime Scheduler = TileScheduler[Self.num_clc_pipeline_stages, Index[dtype=DType.uint32](config.cluster_shape[0], config.cluster_shape[1], config.cluster_shape[2]), block_swizzle_size=config.block_swizzle_size]
SCHEDULER_THREADS
comptime SCHEDULER_THREADS = WARP_SIZE
SmemType
comptime SmemType = Conv2dSmem[act_type, filter_type, out_type, config=config]
src_expected_bytes
comptime src_expected_bytes = Self.OutputM * Self.OutputN * size_of[out_type]()
SrcCTileArray
comptime SrcCTileArray = SMemTileArray[out_type, Self.c_smem_layout, Self.SmemType.num_output_stages, 128]
SrcTileLoaderType
comptime SrcTileLoaderType = TileLoaderTMA[?, ?, ?, ?, cta_group=1]
stage_stride_cols
comptime stage_stride_cols = 512 // Self.num_accum_pipeline_stages
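Because NUM_TMEM_COLS is 512, stage_stride_cols splits tensor memory evenly across the accumulator pipeline stages. A Python sketch of the resulting column ranges, assuming stage i starts at column i * stage_stride_cols (the reference gives only the stride):

```python
NUM_TMEM_COLS = 512


def stage_column_ranges(num_accum_pipeline_stages):
    # Half-open [start, end) TMEM column range for each accumulator stage.
    stride = NUM_TMEM_COLS // num_accum_pipeline_stages  # stage_stride_cols
    return [(i * stride, (i + 1) * stride) for i in range(num_accum_pipeline_stages)]


for stages in (1, 2, 4):
    print(stages, stage_column_ranges(stages))
# 1 [(0, 512)]
# 2 [(0, 256), (256, 512)]
# 4 [(0, 128), (128, 256), (256, 384), (384, 512)]
```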
TilePayload
comptime TilePayload = StandardTilePayload[act_type, filter_type, Self.BM, Self.BK, Self.BN, Self.BK, Self.SmemType.num_pipeline_stages]
TileWriterType
comptime TileWriterType = TileWriter[act_type, Self.accum_type, config.block_tile_shape, config.mma_shape, Self.cta_group, config.num_accum_pipeline_stages, config.c_swizzle, False, Self.SmemType.OutputM, Self.SmemType.OutputN, Self.SmemType.num_output_stages, Self.stage_stride_cols, 4, elementwise_compute_lambda_fn, register_based_epilogue]
TMA_LOAD_THREADS
comptime TMA_LOAD_THREADS = WARP_SIZE
Tmem
comptime Tmem = TmemAllocation[Self.cta_group]
TmemDealloc
comptime TmemDealloc = TmemDeallocBarrier[Self.cta_group]
Methods
mma
static mma[tiles_origin: MutOrigin, //](tmem_stage: TmemStage[config.num_accum_pipeline_stages, Self.stage_stride_cols, Self.cta_group], tiles: InputConsumerStage[tiles_origin, Self.TilePayload, Self.SmemType.num_group_pipeline_stages, config.k_group_size], mma_op: MmaOpSM100_SS[c_type, a_type, b_type, block_tile_shape, mma_shape, accum_type=accum_type, cta_group=cta_group, cluster_shape=cluster_shape, a_swizzle=a_swizzle, b_swizzle=b_swizzle, transpose_b=transpose_b], elect_one_warp: Bool, iter_idx: UInt32, k_start: UInt32)
Execute MMA operations for one pipeline stage.
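Conceptually, each pipeline stage contributes one BK-wide slice of the reduction to the accumulator. The Python sketch below models that K loop with B stored as N x K (MmaOp is instantiated with transpose_b=True). The rule that the first K tile overwrites the accumulator and later tiles add to it is an assumption of the sketch, and `mma_stage` / `tiled_gemm` are hypothetical names, not the kernel's API.

```python
def mma_stage(acc, a_tile, b_tile, accumulate):
    # acc[m][n] (+)= sum_k a_tile[m][k] * b_tile[n][k]   (B is N x K, i.e. transposed)
    for m, a_row in enumerate(a_tile):
        for n, b_row in enumerate(b_tile):
            dot = sum(x * y for x, y in zip(a_row, b_row))
            acc[m][n] = acc[m][n] + dot if accumulate else dot


def tiled_gemm(a, b, bk):
    # a: M x K, b: N x K. Walks K in BK-wide tiles, one "pipeline stage" per tile.
    M, N, K = len(a), len(b), len(a[0])
    acc = [[0.0] * N for _ in range(M)]
    for k0 in range(0, K, bk):
        a_tile = [row[k0:k0 + bk] for row in a]
        b_tile = [row[k0:k0 + bk] for row in b]
        mma_stage(acc, a_tile, b_tile, accumulate=k0 > 0)
    return acc


a = [[1.0, 2.0, 3.0, 4.0, 5.0], [0.0, 1.0, 0.0, 1.0, 0.0]]
b = [[1.0, 1.0, 1.0, 1.0, 1.0], [2.0, 0.0, 2.0, 0.0, 2.0], [0.0, 3.0, 0.0, 3.0, 0.0]]
print(tiled_gemm(a, b, bk=2))  # [[15.0, 18.0, 18.0], [2.0, 0.0, 6.0]]
```

A K extent that is not a multiple of BK (5 with BK = 2 here) is handled by the shorter final slice.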
init_barriers
static init_barriers(ctx: KernelContext[Self.num_clc_pipeline_stages, Self.cta_group, Self.CLUSTER_M, Self.CLUSTER_N], act_tma_op: TMATensorTileIm2col[act_type, act_layout, act_desc_layout], filter_tma_op: TMATensorTile[filter_type, filter_layout, filter_desc_layout], out_tma_op: TMATensorTile[out_type, out_layout, out_desc_layout], input_barriers: SMemArray[SharedMemBarrier, (Conv2dSmem[act_type, filter_type, out_type, config=config].num_group_pipeline_stages * 2)], accum_barriers: SMemArray[SharedMemBarrier, (Conv2dSmem[act_type, filter_type, out_type, config=config].num_accum_pipeline_stages * 2)], clc_throttle: SMemArray[SharedMemBarrier, (Conv2dSmem[act_type, filter_type, out_type, config=config].num_clc_pipeline_stages * 2)], clc_full: SMemArray[SharedMemBarrier, Conv2dSmem[act_type, filter_type, out_type, config=config].num_clc_pipeline_stages], clc_empty: SMemArray[SharedMemBarrier, Conv2dSmem[act_type, filter_type, out_type, config=config].num_clc_pipeline_stages], tmem_dealloc: SMemArray[SharedMemBarrier, 1], epi_load_barriers: SMemArray[SharedMemBarrier, (Conv2dSmem[act_type, filter_type, out_type, config=config].num_epi_load_stages * 2)], load_order_barrier: SMemArray[SharedMemBarrier, 1])
Initialize barriers and prefetch TMA descriptors.
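Several of the barrier arrays are sized num_stages * 2, which fits the common pattern of one "full" and one "empty" barrier per ring-buffer stage. The Python toy model below is written under that assumption, which this reference does not confirm: iteration i uses slot i % num_stages, the phase bit flips each time the ring wraps, and the producer may run at most num_stages iterations ahead of the consumer. `RingPipeline` is a hypothetical illustration, not a kernel type.

```python
def stage_and_phase(iteration, num_stages):
    # Slot index cycles through the ring; the phase bit flips on every wrap-around.
    return iteration % num_stages, (iteration // num_stages) % 2


class RingPipeline:
    """Toy producer/consumer ring; the assertions stand in for barrier waits."""

    def __init__(self, num_stages):
        self.num_stages = num_stages
        self.slots = [None] * num_stages
        self.produced = 0  # iterations committed by the producer ("full" side)
        self.consumed = 0  # iterations released by the consumer ("empty" side)

    def can_produce(self):
        return self.produced - self.consumed < self.num_stages

    def produce(self, item):
        assert self.can_produce(), "producer must wait on the slot's empty barrier"
        slot, _ = stage_and_phase(self.produced, self.num_stages)
        self.slots[slot] = item
        self.produced += 1

    def consume(self):
        assert self.consumed < self.produced, "consumer must wait on the slot's full barrier"
        slot, _ = stage_and_phase(self.consumed, self.num_stages)
        self.consumed += 1
        return self.slots[slot]


print([stage_and_phase(i, 2) for i in range(5)])  # [(0, 0), (1, 0), (0, 1), (1, 1), (0, 0)]
pipe = RingPipeline(num_stages=2)
pipe.produce("k-tile 0")
pipe.produce("k-tile 1")
print(pipe.can_produce())  # False
print(pipe.consume())      # k-tile 0
print(pipe.can_produce())  # True
```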
load_input_tiles
static load_input_tiles[act_tma_origin: ImmutOrigin, filter_tma_origin: ImmutOrigin, tiles_origin: MutOrigin, //](act_loader: TileLoaderTMAIm2col[act_tma_origin, act_type, act_layout, act_desc_layout, cta_group=Self.cta_group], filter_loader: TileLoaderTMA[filter_tma_origin, filter_type, filter_layout, filter_desc_layout, cta_group=Self.cta_group], tiles: InputProducerStage[tiles_origin, Self.TilePayload, Self.SmemType.num_group_pipeline_stages, config.k_group_size], iter_idx: UInt32, work_m_coord: Scalar[DType.uint], work_n_coord: Scalar[DType.uint], peer_cta_coord: Tuple[UInt, UInt, UInt], elect_one_cta: Bool)
Load activation (via im2col TMA) and filter tiles.
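The im2col TMA descriptor handles the coordinate transformation internally, so coordinates are given in GEMM space:
- work_m_coord: coordinate along M, which spans batch * H_out * W_out (one row per output pixel)
- work_n_coord: coordinate along N, which spans the output channels
- iter_idx: tile index along K, which spans C * R * S (input channels times filter height times filter width)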
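The sketch below shows one way GEMM-space coordinates map back to activation coordinates, i.e. the transformation the im2col TMA descriptor handles internally. It assumes NHWC activations and a K dimension ordered (r, s, c) with c fastest; stride, padding, and dilation follow the usual convolution definition. The function names are hypothetical. The reference does not say whether work_m_coord and work_n_coord are element or tile coordinates, so only iter_idx is converted to a K range.

```python
def m_to_npq(m, P, Q):
    # GEMM row m -> output pixel (n, p, q); M = batch * P * Q.
    n, pq = divmod(m, P * Q)
    p, q = divmod(pq, Q)
    return n, p, q


def k_to_rsc(k, S, C):
    # GEMM column k of A -> filter tap (r, s) and input channel c; K = C * R * S.
    rs, c = divmod(k, C)
    r, s = divmod(rs, S)
    return r, s, c


def activation_coord(m, k, P, Q, S, C, stride=1, pad=0, dilation=1):
    # Input element (n, h, w, c) read for A[m][k]; out-of-range h/w lies in the zero padding.
    n, p, q = m_to_npq(m, P, Q)
    r, s, c = k_to_rsc(k, S, C)
    return n, p * stride - pad + r * dilation, q * stride - pad + s * dilation, c


def k_tile_range(iter_idx, bk, gemm_k):
    # K tile iter_idx covers columns [iter_idx * BK, min((iter_idx + 1) * BK, K)).
    return iter_idx * bk, min((iter_idx + 1) * bk, gemm_k)


# 56x56 output, 3x3 filter, 64 input channels, stride 1, padding 1.
print(activation_coord(56 * 56 + 57, 261, P=56, Q=56, S=3, C=64, pad=1))  # (1, 1, 1, 5)
print(activation_coord(0, 0, P=56, Q=56, S=3, C=64, pad=1))              # (0, -1, -1, 0)
print(k_tile_range(8, 64, 576))                                           # (512, 576)
```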
run
static run(act_tma_op: TMATensorTileIm2col[act_type, act_layout, act_desc_layout], filter_tma_op: TMATensorTile[filter_type, filter_layout, filter_desc_layout], out_tma_op: TMATensorTile[out_type, out_layout, out_desc_layout], cluster_dim: StaticTuple[Int32, 3], mnk: StaticTuple[UInt32, 3])
Kernel entry point for Conv2D fprop without a residual input.
Args:
- act_tma_op (TMATensorTileIm2col): Im2col TMA descriptor for activation.
- filter_tma_op (TMATensorTile): TMA descriptor for filter.
- out_tma_op (TMATensorTile): TMA descriptor for output.
- cluster_dim (StaticTuple): Cluster dimensions.
- mnk (StaticTuple): GEMM dimensions (M, N, K).
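The mnk argument holds GEMM-space dimensions. Assuming the standard convolution output-size formula and the M/N/K meanings given for load_input_tiles (M = batch * H_out * W_out, N = output channels, K = C * R * S), a Python sketch for deriving mnk and a tile count on the host might look like this. The problem size and the 128 x 128 x 64 block tile are hypothetical.

```python
def conv_out_extent(extent, kernel, stride=1, pad=0, dilation=1):
    # Standard convolution output size along one spatial dimension.
    return (extent + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1


def conv_fprop_mnk(batch, H, W, C, K, R, S, stride=1, pad=0, dilation=1):
    H_out = conv_out_extent(H, R, stride, pad, dilation)
    W_out = conv_out_extent(W, S, stride, pad, dilation)
    return batch * H_out * W_out, K, C * R * S


def ceil_div(a, b):
    return (a + b - 1) // b


gemm_m, gemm_n, gemm_k = conv_fprop_mnk(batch=2, H=56, W=56, C=64, K=128, R=3, S=3, pad=1)
print(gemm_m, gemm_n, gemm_k)  # 6272 128 576

BM, BN, BK = 128, 128, 64  # hypothetical block_tile_shape
print(ceil_div(gemm_m, BM), ceil_div(gemm_n, BN), ceil_div(gemm_k, BK))  # 49 1 9
```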
run_with_residual
static run_with_residual(act_tma_op: TMATensorTileIm2col[act_type, act_layout, act_desc_layout], filter_tma_op: TMATensorTile[filter_type, filter_layout, filter_desc_layout], out_tma_op: TMATensorTile[out_type, out_layout, out_desc_layout], src_tma_op: TMATensorTile[out_type, src_layout, src_desc_layout], cluster_dim: StaticTuple[Int32, 3], mnk: StaticTuple[UInt32, 3], beta: Float32)
Kernel entry point for Conv2D fprop with a residual input: D = Conv + beta * C.
Args:
- act_tma_op (TMATensorTileIm2col): Im2col TMA descriptor for activation.
- filter_tma_op (TMATensorTile): TMA descriptor for filter.
- out_tma_op (TMATensorTile): TMA descriptor for output D.
- src_tma_op (TMATensorTile): TMA descriptor for source C (residual input).
- cluster_dim (StaticTuple): Cluster dimensions.
- mnk (StaticTuple): GEMM dimensions (M, N, K).
- beta (Float32): Residual scale factor.
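Element-wise, the residual epilogue computes D = Conv + beta * C. This minimal Python sketch applies that formula to one output tile held as nested lists. It ignores data-type conversion, tiling, and the optional elementwise_compute_lambda_fn, whose ordering relative to the residual add is not described here.

```python
def residual_epilogue(acc, src, beta):
    # D[i][j] = acc[i][j] + beta * C[i][j] for one output tile.
    return [[a + beta * c for a, c in zip(acc_row, src_row)]
            for acc_row, src_row in zip(acc, src)]


acc = [[1.0, 2.0], [3.0, 4.0]]      # convolution result (accumulator)
src = [[10.0, 20.0], [30.0, 40.0]]  # residual input C
print(residual_epilogue(acc, src, beta=0.5))  # [[6.0, 12.0], [18.0, 24.0]]
```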