Mojo struct
Conv2dFpropKernel
struct Conv2dFpropKernel[act_type: DType, filter_type: DType, out_type: DType, config: Conv2dConfig[act_type, filter_type, out_type], cluster_shape: StaticTuple[Int32, 3] = StaticTuple(Int32(1)), elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, register_based_epilogue: Bool = True]
SM100 Conv2D forward propagation kernel.
This kernel implements conv2d fprop using implicit GEMM with warp specialization. It reuses the matmul kernel architecture but with convolution-specific address calculation.
The kernel structure:
- Scheduler warp: CLC-based tile scheduling
- Load warp: TMA loads with im2col transformation
- MMA warp: Tensor core operations
- Epilogue warps: Output from TMEM to GMEM
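To make "implicit GEMM" concrete, the plain-Python sketch below computes a small convolution twice. It is an illustration, not Modular code. The first pass uses direct loops; the second treats the problem as a GEMM whose rows of A are im2col patches. The example assumes NHWC activations, KRSC filters, no padding, and an (r, s, c) ordering of the K dimension. The kernel itself never materializes the im2col matrix; its load warp asks the im2col TMA descriptor for those rows directly.

```python
import itertools

def conv2d_direct(x, w, stride=1):
    """Direct convolution. x: [N][H][W][C], w: [K][R][S][C], no padding.
    Returns [N][P][Q][K]."""
    nb, h, wd, c = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    kout, r_, s_ = len(w), len(w[0]), len(w[0][0])
    p_, q_ = (h - r_) // stride + 1, (wd - s_) // stride + 1
    out = [[[[0.0] * kout for _ in range(q_)] for _ in range(p_)] for _ in range(nb)]
    for n, p, q, k in itertools.product(range(nb), range(p_), range(q_), range(kout)):
        acc = 0.0
        for r, s, ci in itertools.product(range(r_), range(s_), range(c)):
            acc += x[n][p * stride + r][q * stride + s][ci] * w[k][r][s][ci]
        out[n][p][q][k] = acc
    return out

def im2col_row(x, m, p_, q_, r_, s_, stride):
    """Row m of the (never stored) A matrix: the input patch of one output pixel."""
    n, rem = divmod(m, p_ * q_)  # GEMM M index -> (batch, p, q)
    p, q = divmod(rem, q_)
    c = len(x[0][0][0])
    return [x[n][p * stride + r][q * stride + s][ci]
            for r in range(r_) for s in range(s_) for ci in range(c)]

def conv2d_implicit_gemm(x, w, stride=1):
    """Same convolution as D[m][n] = sum_k A[m][k] * B[n][k]
    with M = batch * P * Q, N = output channels, K = R * S * C."""
    nb, h, wd = len(x), len(x[0]), len(x[0][0])
    r_, s_ = len(w[0]), len(w[0][0])
    p_, q_ = (h - r_) // stride + 1, (wd - s_) // stride + 1
    # B operand: each filter flattened in the same (r, s, c) order, shape [N][K]
    b_mat = [[v for row in f for col in row for v in col] for f in w]
    d = []
    for m in range(nb * p_ * q_):
        a_row = im2col_row(x, m, p_, q_, r_, s_, stride)
        d.append([sum(a * b for a, b in zip(a_row, b_row)) for b_row in b_mat])
    return d  # row m holds the output channels of output pixel m
```

The GEMM result is the NPQK output flattened to M rows, so both paths agree element for element.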
Parameters
- act_type (DType): Activation data type.
- filter_type (DType): Filter data type.
- out_type (DType): Output data type.
- config (Conv2dConfig[act_type, filter_type, out_type]): Kernel configuration.
- cluster_shape (StaticTuple[Int32, 3]): CUDA cluster dimensions.
- elementwise_lambda_fn (Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None]): Optional void epilogue lambda applied after the output is written. Signature: def(IndexList[2], SIMD) -> None.
- elementwise_compute_lambda_fn (Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]]): Optional epilogue lambda for fusion (bias add, activation functions, residual connections).
- register_based_epilogue (Bool): Whether to apply the lambda in registers.
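The Python model below pictures the compute lambda by fusing a per-channel bias add with a ReLU, two of the fusions named above. It is a sketch with two assumed conventions. First, IndexList[2] carries the (row, col) of the first element. Second, the SIMD vector holds width consecutive values along N. In Mojo the lambda is a compile-time capturing closure rather than a runtime function object, and make_bias_relu and apply_epilogue are hypothetical helpers.

```python
from typing import Callable, List, Tuple

# Hypothetical model of a compute lambda: it receives the (row, col) coordinate of
# the first element and `width` consecutive output values along N, and returns
# the transformed vector.
ComputeLambda = Callable[[Tuple[int, int], List[float]], List[float]]

def make_bias_relu(bias: List[float]) -> ComputeLambda:
    """Fused bias add (one bias per output channel, i.e. per column) plus ReLU."""
    def fn(idx: Tuple[int, int], vec: List[float]) -> List[float]:
        _, col = idx
        return [max(0.0, v + bias[col + i]) for i, v in enumerate(vec)]
    return fn

def apply_epilogue(tile: List[List[float]], fn: ComputeLambda, width: int) -> List[List[float]]:
    """Apply fn to a row-major output tile in chunks of `width` columns."""
    out = []
    for r, row in enumerate(tile):
        new_row: List[float] = []
        for c in range(0, len(row), width):
            new_row.extend(fn((r, c), row[c:c + width]))
        out.append(new_row)
    return out
```

Because the lambda only sees a coordinate and a vector, the same function works for any vector width the epilogue happens to use.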
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
accum_layout
comptime accum_layout = Layout.row_major(Self.MMA_M, Self.MMA_N)
accum_pipeline_consumer_arv_count
comptime accum_pipeline_consumer_arv_count = (Self.cta_group * Self.EPILOGUE_THREADS)
accum_pipeline_producer_arv_count
comptime accum_pipeline_producer_arv_count = 1
accum_type
comptime accum_type = Conv2dConfig.accum_type()
AccumTensor
comptime AccumTensor = TmemTensor[Self.accum_type, Self.accum_layout, cta_group=Self.cta_group]
act_expected_bytes
comptime act_expected_bytes = (Self.SmemType.act_smem_elements * size_of[act_type]())
act_swizzle_elems
comptime act_swizzle_elems = (config.a_swizzle.bytes() // size_of[act_type]())
act_tile_dim0
comptime act_tile_dim0 = (Self.BM // Self.CLUSTER_N)
act_tma_load_size
comptime act_tma_load_size = (Self.act_tile_dim0 * Self.act_swizzle_elems)
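The plain-Python check below runs the activation-tile arithmetic with assumed example values: BM = 128, a 128-byte swizzle, and 2-byte bf16 elements. One possible reading of act_tile_dim0 is that each of the CLUSTER_N CTAs loads BM // CLUSTER_N rows of a shared activation tile. That is an interpretation; this page does not state it.

```python
def act_tma_box(bm: int, cluster_n: int, swizzle_bytes: int, elem_bytes: int):
    """Mirror of act_swizzle_elems, act_tile_dim0 and act_tma_load_size."""
    act_swizzle_elems = swizzle_bytes // elem_bytes        # elements per swizzle row
    act_tile_dim0 = bm // cluster_n                        # rows per TMA load (act_tma_rows)
    act_tma_load_size = act_tile_dim0 * act_swizzle_elems  # elements per TMA load
    return act_swizzle_elems, act_tile_dim0, act_tma_load_size
```

With CLUSTER_N = 1 one load covers 128 x 64 bf16 elements (16 KiB). With CLUSTER_N = 2 each load covers half as many rows.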
act_tma_rows
comptime act_tma_rows = Self.act_tile_dim0
ActDescLayout
comptime ActDescLayout = Layout[*?, *?]
ActTileLayout
comptime ActTileLayout = Layout[*?, *?]
ActTileLoaderTypeIm2col
comptime ActTileLoaderTypeIm2col = TileLoaderTMAIm2col[?, ?, ?, ?, ?, cta_group=Self.cta_group]
ActTmaOp
comptime ActTmaOp = TMATensorTileIm2col[act_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]
BK
comptime BK = config.block_tile_shape[2]
BM
comptime BM = config.block_tile_shape[0]
BN
comptime BN = config.block_tile_shape[1]
clc_consumer_arv_count
comptime clc_consumer_arv_count = (Self.SCHEDULER_THREADS + (Self.CLUSTER_SIZE * (((Self.TMA_LOAD_THREADS + Self.MMA_THREADS) + Self.EPILOGUE_LOAD_THREADS) + Self.EPILOGUE_THREADS)))
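The arrival counts are plain sums of per-role thread counts. The Python sketch below evaluates them for a few cluster shapes, assuming WARP_SIZE = 32 as on NVIDIA GPUs. One interpretation of clc_consumer_arv_count is one scheduler warp plus every non-scheduler thread of every CTA in the cluster.

```python
WARP_SIZE = 32  # NVIDIA warp width
SCHEDULER_THREADS = TMA_LOAD_THREADS = MMA_THREADS = EPILOGUE_LOAD_THREADS = WARP_SIZE
EPILOGUE_THREADS = 4 * WARP_SIZE

def clc_consumer_arv_count(cluster_m: int, cluster_n: int) -> int:
    cluster_size = cluster_m * cluster_n
    # all non-scheduler threads of one CTA
    per_cta = TMA_LOAD_THREADS + MMA_THREADS + EPILOGUE_LOAD_THREADS + EPILOGUE_THREADS
    return SCHEDULER_THREADS + cluster_size * per_cta

def accum_pipeline_consumer_arv_count(cta_group: int) -> int:
    return cta_group * EPILOGUE_THREADS
```

Each extra CTA in the cluster adds 224 expected arrivals (32 + 32 + 32 + 128).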
clc_producer_arv_count
comptime clc_producer_arv_count = 1
clc_throttle_consumer_arv_count
comptime clc_throttle_consumer_arv_count = Self.SCHEDULER_THREADS
clc_throttle_producer_arv_count
comptime clc_throttle_producer_arv_count = Self.TMA_LOAD_THREADS
CLUSTER_M
comptime CLUSTER_M = config.cluster_shape[0]
CLUSTER_N
comptime CLUSTER_N = config.cluster_shape[1]
CLUSTER_SIZE
comptime CLUSTER_SIZE = (Self.CLUSTER_M * Self.CLUSTER_N)
Context
comptime Context = KernelContext[Self.num_clc_pipeline_stages, Self.cta_group, Self.CLUSTER_M, Self.CLUSTER_N]
cta_group
comptime cta_group = config.cta_group
epi_load_consumer_arv_count
comptime epi_load_consumer_arv_count = SIMD(Self.EPILOGUE_THREADS)
epi_load_producer_arv_count
comptime epi_load_producer_arv_count = Int32(1)
EpiLoadPipelineType
comptime EpiLoadPipelineType = EpiLoadPipeline[Self.num_epi_load_stages]
EPILOGUE_LOAD_THREADS
comptime EPILOGUE_LOAD_THREADS = WARP_SIZE
EPILOGUE_THREADS
comptime EPILOGUE_THREADS = (4 * WARP_SIZE)
EpilogueCtx
comptime EpilogueCtx = EpilogueWarpContext[Self.opc, Self.MMA_THREADS, Self.EPILOGUE_THREADS]
filter_expected_bytes
comptime filter_expected_bytes = (Self.SmemType.filter_smem_elements * size_of[filter_type]())
filter_swizzle_elems
comptime filter_swizzle_elems = (config.b_swizzle.bytes() // size_of[filter_type]())
filter_tile_dim0
comptime filter_tile_dim0 = (Self.BN // (Self.CLUSTER_M // Self.cta_group))
filter_tma_load_size
comptime filter_tma_load_size = (Self.filter_tile_dim0 * Self.filter_swizzle_elems)
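filter_tile_dim0 divides BN by CLUSTER_M // cta_group, so a 2-CTA MMA pair counts as one unit along M. The Python mirror below uses assumed example values. The guard against CLUSTER_M < cta_group is a hypothetical addition: without it, the formula would divide by zero.

```python
def filter_tma_box(bn: int, cluster_m: int, cta_group: int, swizzle_bytes: int, elem_bytes: int):
    """Mirror of filter_swizzle_elems, filter_tile_dim0 and filter_tma_load_size."""
    split = cluster_m // cta_group  # number of cta_group units along the cluster's M axis
    assert split >= 1, "hypothetical guard: requires CLUSTER_M >= cta_group"
    filter_swizzle_elems = swizzle_bytes // elem_bytes
    filter_tile_dim0 = bn // split
    return filter_tile_dim0, filter_tile_dim0 * filter_swizzle_elems
```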
filter_tma_rows
comptime filter_tma_rows = Self.filter_tile_dim0
FilterDescLayout
comptime FilterDescLayout = Layout[*?, *?]
FilterTileLayout
comptime FilterTileLayout = Layout[*?, *?]
FilterTileLoaderType
comptime FilterTileLoaderType = TileLoaderTMA[?, ?, ?, ?, ?, cta_group=Self.cta_group]
FilterTmaOp
comptime FilterTmaOp = TMATensorTile[filter_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]
input_expected_bytes
comptime input_expected_bytes = ((Self.cta_group * (Self.act_expected_bytes + Self.filter_expected_bytes)) * config.k_group_size)
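The sketch below evaluates input_expected_bytes, assuming it is the byte count each input-pipeline barrier waits for. act_smem_elements and filter_smem_elements come from Conv2dSmem and are not documented on this page. The 128 x 64 tile sizes in the example are therefore assumptions. So is reading the trailing factor as config.k_group_size, the group size InputTilePipelineType is parameterized with.

```python
def input_expected_bytes(cta_group: int, act_smem_elements: int, filter_smem_elements: int,
                         act_bytes: int, filter_bytes: int, k_group_size: int) -> int:
    act_expected = act_smem_elements * act_bytes           # act_expected_bytes
    filter_expected = filter_smem_elements * filter_bytes  # filter_expected_bytes
    return cta_group * (act_expected + filter_expected) * k_group_size

# Assumed example: 128 x 64 bf16 tiles for both operands.
TILE = 128 * 64
```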
InputTilePipelineType
comptime InputTilePipelineType = InputTilePipeline[StandardTilePayload[act_type, filter_type, IndexList(Self.BM, Self.BK), IndexList(Self.BN, Self.BK), Self.SmemType.num_pipeline_stages], Self.SmemType.num_group_pipeline_stages, config.k_group_size]
MMA_K
comptime MMA_K = config.mma_shape[2]
MMA_M
comptime MMA_M = config.mma_shape[0]
MMA_N
comptime MMA_N = config.mma_shape[1]
MMA_THREADS
comptime MMA_THREADS = WARP_SIZE
MmaCtx
comptime MmaCtx = MmaWarpContext[Self.opc, Self.MMA_THREADS, Self.EPILOGUE_THREADS]
MmaEpilogueSync
comptime MmaEpilogueSync = WarpGroupBarrier[(Self.MMA_THREADS + Self.EPILOGUE_THREADS), 1]
MmaOp
comptime MmaOp = MmaOpSM100_SS[out_type, act_type, filter_type, config.block_tile_shape, config.mma_shape, accum_type=Self.accum_type, cta_group=Self.cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=True]
num_accum_pipeline_stages
comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages
num_clc_pipeline_stages
comptime num_clc_pipeline_stages = config.num_clc_pipeline_stages
num_epi_load_stages
comptime num_epi_load_stages = Self.SmemType.num_epi_load_stages
num_group_pipeline_stages
comptime num_group_pipeline_stages = (Self.num_pipeline_stages // config.k_group_size)
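Grouping k_group_size pipeline stages shrinks the number of input barriers, because init_barriers sizes input_barriers as num_group_pipeline_stages * 2. The sketch below evaluates both numbers in Python. The example uses config.k_group_size as the divisor. Reading the factor of 2 as one full barrier and one empty barrier per group is an assumption.

```python
def input_barrier_counts(num_pipeline_stages: int, k_group_size: int):
    num_group_pipeline_stages = num_pipeline_stages // k_group_size
    # init_barriers: SMemArray[SharedMemBarrier, num_group_pipeline_stages * 2]
    num_input_barriers = num_group_pipeline_stages * 2
    return num_group_pipeline_stages, num_input_barriers
```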
num_output_stages
comptime num_output_stages = config.num_output_stages
num_output_warps
comptime num_output_warps = 4
num_pipeline_stages
comptime num_pipeline_stages = config.num_pipeline_stages
NUM_THREADS
comptime NUM_THREADS = ((((Self.SCHEDULER_THREADS + Self.TMA_LOAD_THREADS) + Self.MMA_THREADS) + Self.EPILOGUE_LOAD_THREADS) + Self.EPILOGUE_THREADS)
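NUM_THREADS adds one warp each for scheduling, TMA loads, MMA issue, and epilogue loads, plus four epilogue warps (matching num_output_warps = 4). The Python tally below assumes WARP_SIZE = 32. The dict order follows the summation order in NUM_THREADS; this page does not say which warp IDs are assigned to which role.

```python
WARP_SIZE = 32
ROLE_THREADS = {
    "scheduler": WARP_SIZE,      # SCHEDULER_THREADS
    "tma_load": WARP_SIZE,       # TMA_LOAD_THREADS
    "mma": WARP_SIZE,            # MMA_THREADS
    "epilogue_load": WARP_SIZE,  # EPILOGUE_LOAD_THREADS
    "epilogue": 4 * WARP_SIZE,   # EPILOGUE_THREADS
}
NUM_THREADS = sum(ROLE_THREADS.values())
NUM_WARPS = NUM_THREADS // WARP_SIZE
# MmaEpilogueSync is a barrier over the MMA warp plus the epilogue warps
MMA_EPILOGUE_SYNC_THREADS = ROLE_THREADS["mma"] + ROLE_THREADS["epilogue"]
```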
NUM_TMEM_COLS
comptime NUM_TMEM_COLS = 512
opc
comptime opc = OutputPipelineConfig(Self.num_accum_pipeline_stages, Self.stage_stride_cols, Self.cta_group)
out_swizzle_elems
comptime out_swizzle_elems = (config.c_swizzle.bytes() // size_of[out_type]())
out_tile_dim0
comptime out_tile_dim0 = Self.OutputM if (Self.MMA_M == 256 or Self.cta_group == 1) else 64
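A small Python model can check the selection rule for out_tile_dim0. It treats the condition as MMA_M == 256 or cta_group == 1, because a conditional of the form x if x else y evaluates the same as x or y. The second function keeps the nested form so the two can be compared. The input values are arbitrary examples.

```python
def out_tile_dim0(output_m: int, mma_m: int, cta_group: int) -> int:
    """Model of out_tile_dim0 using `or`."""
    return output_m if (mma_m == 256 or cta_group == 1) else 64

def out_tile_dim0_nested(output_m: int, mma_m: int, cta_group: int) -> int:
    """Same rule with the condition written as `x if x else y`."""
    return output_m if ((mma_m == 256) if (mma_m == 256) else (cta_group == 1)) else 64
```

Only a 2-CTA configuration whose MMA_M is not 256 falls back to 64 rows.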
OutDescLayout
comptime OutDescLayout = Layout[*?, *?]
OutputM
comptime OutputM = config.output_tile_shape[0]
OutputN
comptime OutputN = config.output_tile_shape[1]
OutputPipeline
comptime OutputPipeline = OutputTilePipeline[Self.opc]
OutTileLayout
comptime OutTileLayout = Layout[*?, *?]
OutTmaOp
comptime OutTmaOp = TMATensorTile[out_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]
Scheduler
comptime Scheduler = TileScheduler[Self.num_clc_pipeline_stages, Index[Int, Int, Int, dtype=DType.uint32](config.cluster_shape[0], config.cluster_shape[1], config.cluster_shape[2]), block_swizzle_size=config.block_swizzle_size]
SCHEDULER_THREADS
comptime SCHEDULER_THREADS = WARP_SIZE
SmemType
comptime SmemType = Conv2dSmem[act_type, filter_type, out_type, config=config]
src_expected_bytes
comptime src_expected_bytes = ((Self.OutputM * Self.OutputN) * size_of[out_type]())
SrcCTileArray
comptime SrcCTileArray = SMemTileArray2DRowMajor[out_type, Self.SmemType.OutputM, Self.SmemType.OutputN, Self.SmemType.num_output_stages]
SrcDescLayout
comptime SrcDescLayout = Self.OutDescLayout
SrcTileLayout
comptime SrcTileLayout = Self.OutTileLayout
SrcTileLoaderType
comptime SrcTileLoaderType = TileLoaderTMA[?, ?, ?, ?, ?, cta_group=1]
SrcTmaOp
comptime SrcTmaOp = TMATensorTile[out_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]
stage_stride_cols
comptime stage_stride_cols = (512 // Self.num_accum_pipeline_stages)
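stage_stride_cols splits the 512 TMEM columns evenly across accumulator stages. The Python sketch below places stage i at column i * stage_stride_cols. That placement is an assumption about how OutputPipelineConfig uses the stride; this page does not document it.

```python
NUM_TMEM_COLS = 512  # matches NUM_TMEM_COLS

def accum_stage_columns(num_accum_pipeline_stages: int):
    stride = NUM_TMEM_COLS // num_accum_pipeline_stages  # stage_stride_cols
    # hypothetical: base TMEM column of each accumulator stage
    bases = [i * stride for i in range(num_accum_pipeline_stages)]
    return stride, bases
```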
TilePayload
comptime TilePayload = StandardTilePayload[act_type, filter_type, IndexList(Self.BM, Self.BK), IndexList(Self.BN, Self.BK), Self.SmemType.num_pipeline_stages]
TileWriterType
comptime TileWriterType = TileWriter[act_type, Self.accum_type, config.block_tile_shape, config.mma_shape, Self.opc, config.c_swizzle, False, Self.SmemType.OutputM, Self.SmemType.OutputN, Self.SmemType.num_output_stages, 4, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue]
TMA_LOAD_THREADS
comptime TMA_LOAD_THREADS = WARP_SIZE
Tmem
comptime Tmem = TmemAllocation[Self.opc.cta_group]
TmemDealloc
comptime TmemDealloc = TmemDeallocBarrier[Self.opc.cta_group]
Methods
mma
static mma[tiles_origin: MutOrigin, //](tmem_stage: TmemStage[Self.opc], tiles: ConsumerTiles[tiles_origin, StandardTilePayload[act_type, filter_type, IndexList(Self.BM, Self.BK), IndexList(Self.BN, Self.BK), Self.SmemType.num_pipeline_stages], Self.SmemType.num_group_pipeline_stages, config.k_group_size], mma_op: MmaOpSM100_SS[accum_type=mma_op.accum_type, cta_group=mma_op.cta_group, cluster_shape=mma_op.cluster_shape, a_swizzle=mma_op.a_swizzle, b_swizzle=mma_op.b_swizzle, transpose_b=mma_op.transpose_b], elect_one_warp: Bool, iter_idx: UInt32, k_start: UInt32)
Execute MMA operations for one pipeline stage.
init_barriers
static init_barriers(ctx: KernelContext[Self.num_clc_pipeline_stages, Self.cta_group, Self.CLUSTER_M, Self.CLUSTER_N], act_tma_op: TMATensorTileIm2col[act_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], filter_tma_op: TMATensorTile[filter_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], out_tma_op: TMATensorTile[out_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], input_barriers: SMemArray[SharedMemBarrier, (Conv2dSmem[act_type, filter_type, out_type, config=config].num_group_pipeline_stages * 2)], accum_barriers: SMemArray[SharedMemBarrier, (Conv2dSmem[act_type, filter_type, out_type, config=config].num_accum_pipeline_stages * 2)], clc_throttle: SMemArray[SharedMemBarrier, (Conv2dSmem[act_type, filter_type, out_type, config=config].num_clc_pipeline_stages * 2)], clc_full: SMemArray[SharedMemBarrier, Conv2dSmem[act_type, filter_type, out_type, config=config].num_clc_pipeline_stages], clc_empty: SMemArray[SharedMemBarrier, Conv2dSmem[act_type, filter_type, out_type, config=config].num_clc_pipeline_stages], tmem_dealloc: SMemArray[SharedMemBarrier, 1], epi_load_barriers: SMemArray[SharedMemBarrier, (Conv2dSmem[act_type, filter_type, out_type, config=config].num_epi_load_stages * 2)], load_order_barrier: SMemArray[SharedMemBarrier, 1])
Initialize barriers and prefetch TMA descriptors.
load_input_tiles
static load_input_tiles[act_tma_origin: ImmutOrigin, filter_tma_origin: ImmutOrigin, tiles_origin: MutOrigin, //](act_loader: TileLoaderTMAIm2col[act_tma_origin, act_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]](), cta_group=Self.cta_group], filter_loader: TileLoaderTMA[filter_tma_origin, filter_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]](), cta_group=Self.cta_group], tiles: ProducerTiles[tiles_origin, StandardTilePayload[act_type, filter_type, IndexList(Self.BM, Self.BK), IndexList(Self.BN, Self.BK), Self.SmemType.num_pipeline_stages], Self.SmemType.num_group_pipeline_stages, config.k_group_size], iter_idx: UInt32, work_m_coord: Int, work_n_coord: Int, peer_cta_coord: Tuple[Int, Int, Int], elect_one_cta: Bool)
Load activation (via im2col TMA) and filter tiles.
The im2col TMA descriptor handles coordinate transformation internally. Coordinates are in GEMM space:
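- work_m_coord: coordinate along M, which spans batch * H_out * W_out output pixels
- work_n_coord: coordinate along N, which spans the output channels
- iter_idx: tile index along K, which spans C * R * S (input channels times filter height times filter width)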
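The Python sketch below works through, per element, the coordinate mapping the im2col TMA performs. The stride, pad and dilation parameters, the NHWC layout, and the (r, s, c) ordering of K (c fastest) are assumptions for illustration. This page also does not say whether work_m_coord is an element coordinate or a tile coordinate. Elements that land in the padding halo return None; the TMA hardware presumably fills them with zeros.

```python
def gemm_to_input_coord(m, k, *, P, Q, R, S, C, H, W, stride=1, pad=0, dilation=1):
    """Map a GEMM-space (m, k) element to an NHWC input coordinate (n, h, w, c).

    m indexes batch * P * Q output pixels; k indexes R * S * C patch elements,
    assumed ordered (r, s, c) with c fastest. Returns None inside the padding halo.
    """
    assert 0 <= k < R * S * C
    n, pq = divmod(m, P * Q)   # which image, and which output pixel within it
    p, q = divmod(pq, Q)       # output row and column
    rs, c = divmod(k, C)       # filter tap and input channel
    r, s = divmod(rs, S)       # filter row and column
    h = p * stride - pad + r * dilation
    w = q * stride - pad + s * dilation
    if 0 <= h < H and 0 <= w < W:
        return (n, h, w, c)
    return None
```

For example, with a 4 x 4 input, a 3 x 3 filter, 2 channels, and pad = 1, GEMM element (m=0, k=9) reads input pixel (0, 0, 0, 1).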
run
static run(act_tma_op: TMATensorTileIm2col[act_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], filter_tma_op: TMATensorTile[filter_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], out_tma_op: TMATensorTile[out_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], cluster_dim: StaticTuple[Int32, 3], mnk: StaticTuple[UInt32, 3])
Kernel entry point for Conv2D fprop (no residual).
Args:
- act_tma_op (TMATensorTileIm2col[act_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]): Im2col TMA descriptor for activation.
- filter_tma_op (TMATensorTile[filter_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]): TMA descriptor for filter.
- out_tma_op (TMATensorTile[out_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]): TMA descriptor for output.
- cluster_dim (StaticTuple[Int32, 3]): Cluster dimensions.
- mnk (StaticTuple[UInt32, 3]): GEMM dimensions (M, N, K).
run_with_residual
static run_with_residual(act_tma_op: TMATensorTileIm2col[act_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], filter_tma_op: TMATensorTile[filter_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], out_tma_op: TMATensorTile[out_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], src_tma_op: TMATensorTile[out_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], cluster_dim: StaticTuple[Int32, 3], mnk: StaticTuple[UInt32, 3], beta: Float32)
Kernel entry point for Conv2D fprop with residual (D = Conv + beta*C).
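The residual epilogue is an elementwise update. In the Python sketch below, acc stands for the accumulator tile drained from TMEM and src stands for the C tile read through src_tma_op. The sketch leaves out conversion to out_type and the tiling across epilogue warps, and residual_epilogue is a hypothetical name.

```python
def residual_epilogue(acc, src, beta):
    """D = acc + beta * C, elementwise over a row-major tile."""
    return [[a + beta * c for a, c in zip(acc_row, src_row)]
            for acc_row, src_row in zip(acc, src)]
```

With beta = 0.0 the result equals the plain convolution output, which is what run produces.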
Args:
- act_tma_op (TMATensorTileIm2col[act_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]): Im2col TMA descriptor for activation.
- filter_tma_op (TMATensorTile[filter_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]): TMA descriptor for filter.
- out_tma_op (TMATensorTile[out_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]): TMA descriptor for output D.
- src_tma_op (TMATensorTile[out_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]): TMA descriptor for source C (residual input).
- cluster_dim (StaticTuple[Int32, 3]): Cluster dimensions.
- mnk (StaticTuple[UInt32, 3]): GEMM dimensions (M, N, K).
- beta (Float32): Residual scale factor.