IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.
Skip to main content
For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).

Mojo struct

Conv2dFpropKernel

struct Conv2dFpropKernel[act_type: DType, filter_type: DType, out_type: DType, config: Conv2dConfig[act_type, filter_type, out_type], cluster_shape: StaticTuple[Int32, Int(3)] = StaticTuple(Int32(1)), elementwise_lambda_fn: Optional[def[dtype: DType, width: SIMDSize, *, alignment: Int = Int(1)](IndexList[Int(2)], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: Optional[def[dtype: DType, width: SIMDSize, *, alignment: Int = Int(1)](IndexList[Int(2)], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, register_based_epilogue: Bool = True]

SM100 Conv2D forward propagation kernel.

This kernel implements conv2d fprop using implicit GEMM with warp specialization. It reuses the matmul kernel architecture but with convolution-specific address calculation.

The kernel structure:

  • Scheduler warp: CLC-based tile scheduling
  • Load warp: TMA loads with im2col transformation
  • MMA warp: Tensor core operations
  • Epilogue warps: Output from TMEM to GMEM

Parameters​

Implemented traits​

AnyType, ImplicitlyDeletable

comptime members​

accum_layout​

comptime accum_layout = Layout.row_major(config.mma_shape[Int(0)], config.mma_shape[Int(1)])

accum_pipeline_consumer_arv_count​

comptime accum_pipeline_consumer_arv_count = (config.cta_group * Int((mul _resolve_warp_size(), 4)))

accum_pipeline_producer_arv_count​

comptime accum_pipeline_producer_arv_count = 1

accum_type​

comptime accum_type = Conv2dConfig.accum_type()

AccumTensor​

comptime AccumTensor = TmemTensor[Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].accum_type, Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].accum_layout, cta_group=config.cta_group]

act_expected_bytes​

comptime act_expected_bytes = (Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.act_smem_elements * size_of[act_type]())

act_swizzle_elems​

comptime act_swizzle_elems = (config.a_swizzle.bytes() // size_of[act_type]())

act_tile_dim0​

comptime act_tile_dim0 = (config.block_tile_shape[Int(0)] // config.cluster_shape[Int(1)])

act_tma_load_size​

comptime act_tma_load_size = ((config.block_tile_shape[Int(0)] // config.cluster_shape[Int(1)]) * (config.a_swizzle.bytes() // size_of[act_type]()))

act_tma_rows​

comptime act_tma_rows = Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].act_tile_dim0

ActDescLayout​

comptime ActDescLayout = Layout[*?, *?]

ActTileLayout​

comptime ActTileLayout = Layout[*?, *?]

ActTileLoaderTypeIm2col​

comptime ActTileLoaderTypeIm2col = TileLoaderTMAIm2col[_, _, _, _, _, cta_group=config.cta_group]

ActTmaOp​

comptime ActTmaOp = TMATensorTileIm2col[act_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()]

BK​

comptime BK = config.block_tile_shape[Int(2)]

BM​

comptime BM = config.block_tile_shape[Int(0)]

BN​

comptime BN = config.block_tile_shape[Int(1)]

clc_consumer_arv_count​

comptime clc_consumer_arv_count = (_resolve_warp_size() + Int((mul config.cluster_shape[Int(1)], config.cluster_shape[Int(0)], _resolve_warp_size(), 7)))

clc_producer_arv_count​

comptime clc_producer_arv_count = 1

clc_throttle_consumer_arv_count​

comptime clc_throttle_consumer_arv_count = Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].SCHEDULER_THREADS

clc_throttle_producer_arv_count​

comptime clc_throttle_producer_arv_count = Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].TMA_LOAD_THREADS

CLUSTER_M​

comptime CLUSTER_M = config.cluster_shape[Int(0)]

CLUSTER_N​

comptime CLUSTER_N = config.cluster_shape[Int(1)]

CLUSTER_SIZE​

comptime CLUSTER_SIZE = (config.cluster_shape[Int(0)] * config.cluster_shape[Int(1)])

Context​

comptime Context = KernelContext[config.num_clc_pipeline_stages, config.cta_group, config.cluster_shape[Int(0)], config.cluster_shape[Int(1)]]

cta_group​

comptime cta_group = config.cta_group

epi_load_consumer_arv_count​

comptime epi_load_consumer_arv_count = SIMD(Int((mul _resolve_warp_size(), 4)))

epi_load_producer_arv_count​

comptime epi_load_producer_arv_count = Int32(1)

EpiLoadPipelineType​

comptime EpiLoadPipelineType = EpiLoadPipeline[(config.mma_shape[Int(1)] // config.output_tile_shape[Int(1)])]

EPILOGUE_LOAD_THREADS​

comptime EPILOGUE_LOAD_THREADS = WARP_SIZE

EPILOGUE_THREADS​

comptime EPILOGUE_THREADS = (Int(4) * _resolve_warp_size())

EpilogueCtx​

comptime EpilogueCtx = EpilogueWarpContext[Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].opc, _resolve_warp_size(), Int((mul _resolve_warp_size(), 4))]

filter_expected_bytes​

comptime filter_expected_bytes = (Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.filter_smem_elements * size_of[filter_type]())

filter_swizzle_elems​

comptime filter_swizzle_elems = (config.b_swizzle.bytes() // size_of[filter_type]())

filter_tile_dim0​

comptime filter_tile_dim0 = (config.block_tile_shape[Int(1)] // (config.cluster_shape[Int(0)] // config.cta_group))

filter_tma_load_size​

comptime filter_tma_load_size = ((config.block_tile_shape[Int(1)] // (config.cluster_shape[Int(0)] // config.cta_group)) * (config.b_swizzle.bytes() // size_of[filter_type]()))

filter_tma_rows​

comptime filter_tma_rows = Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].filter_tile_dim0

FilterDescLayout​

comptime FilterDescLayout = Layout[*?, *?]

FilterTileLayout​

comptime FilterTileLayout = Layout[*?, *?]

FilterTileLoaderType​

comptime FilterTileLoaderType = TileLoaderTMA[_, _, _, _, _, cta_group=config.cta_group]

FilterTmaOp​

comptime FilterTmaOp = TMATensorTile[filter_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()]

input_expected_bytes​

comptime input_expected_bytes = (Int((add (mul (config.block_tile_shape[Int(2)] // (config.a_swizzle.bytes() // size_of[act_type]())), (config.block_tile_shape[Int(0)] // Int(8)), (config.a_swizzle.bytes() // size_of[act_type]()), size_of[act_type](), config.cta_group, 8), (mul (config.block_tile_shape[Int(2)] // (config.b_swizzle.bytes() // size_of[filter_type]())), (config.block_tile_shape[Int(1)] // Int(8)), (config.b_swizzle.bytes() // size_of[filter_type]()), size_of[filter_type](), config.cta_group, 8))) * config.k_group_size)

InputTilePipelineType​

comptime InputTilePipelineType = InputTilePipeline[StandardTilePayload[act_type, filter_type, IndexList(config.block_tile_shape[Int(0)], config.block_tile_shape[Int(2)], __list_literal__=NoneType(None)), IndexList(config.block_tile_shape[Int(1)], config.block_tile_shape[Int(2)], __list_literal__=NoneType(None)), config.num_pipeline_stages], (config.num_pipeline_stages // config.k_group_size), config.k_group_size]

MMA_K​

comptime MMA_K = config.mma_shape[Int(2)]

MMA_M​

comptime MMA_M = config.mma_shape[Int(0)]

MMA_N​

comptime MMA_N = config.mma_shape[Int(1)]

MMA_THREADS​

comptime MMA_THREADS = WARP_SIZE

MmaCtx​

comptime MmaCtx = MmaWarpContext[Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].opc, _resolve_warp_size(), Int((mul _resolve_warp_size(), 4))]

MmaEpilogueSync​

comptime MmaEpilogueSync = WarpGroupBarrier[(_resolve_warp_size() + Int((mul _resolve_warp_size(), 4))), Int(1)]

MmaOp​

comptime MmaOp = MmaOpSM100_SS[out_type, act_type, filter_type, config.block_tile_shape, config.mma_shape, accum_type=Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].accum_type, cta_group=config.cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=True]

num_accum_pipeline_stages​

comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages

num_clc_pipeline_stages​

comptime num_clc_pipeline_stages = config.num_clc_pipeline_stages

num_epi_load_stages​

comptime num_epi_load_stages = Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.num_epi_load_stages

num_group_pipeline_stages​

comptime num_group_pipeline_stages = (config.num_pipeline_stages // config.k_group_size)

num_output_stages​

comptime num_output_stages = config.num_output_stages

num_output_warps​

comptime num_output_warps = 4

num_pipeline_stages​

comptime num_pipeline_stages = config.num_pipeline_stages

NUM_THREADS​

comptime NUM_THREADS = (Int((mul _resolve_warp_size(), 4)) + Int((mul _resolve_warp_size(), 4)))

NUM_TMEM_COLS​

comptime NUM_TMEM_COLS = 512

opc​

comptime opc = OutputPipelineConfig(Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].num_accum_pipeline_stages, Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].stage_stride_cols, Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].cta_group)

out_swizzle_elems​

comptime out_swizzle_elems = (config.c_swizzle.bytes() // size_of[out_type]())

out_tile_dim0​

comptime out_tile_dim0 = Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].OutputM if ((config.mma_shape[Int(0)] == Int(256)) or (config.cta_group == Int(1))) else Int(64)

OutDescLayout​

comptime OutDescLayout = Layout[*?, *?]

OutputM​

comptime OutputM = config.output_tile_shape[Int(0)]

OutputN​

comptime OutputN = config.output_tile_shape[Int(1)]

OutputPipeline​

comptime OutputPipeline = OutputTilePipeline[Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].opc]

OutTileLayout​

comptime OutTileLayout = Layout[*?, *?]

OutTmaOp​

comptime OutTmaOp = TMATensorTile[out_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()]

Scheduler​

comptime Scheduler = TileScheduler[config.num_clc_pipeline_stages, Index[Int, Int, Int, dtype=DType.uint32](config.cluster_shape[Int(0)], config.cluster_shape[Int(1)], config.cluster_shape[Int(2)]), block_swizzle_size=config.block_swizzle_size]

SCHEDULER_THREADS​

comptime SCHEDULER_THREADS = WARP_SIZE

SmemType​

comptime SmemType = Conv2dSmem[act_type, filter_type, out_type, config=config]

src_expected_bytes​

comptime src_expected_bytes = (Int((mul config.output_tile_shape[Int(0)], config.output_tile_shape[Int(1)])) * size_of[out_type]())

SrcCTileArray​

comptime SrcCTileArray = SMemTileArray2DRowMajor[out_type, config.output_tile_shape[Int(0)], config.output_tile_shape[Int(1)], (config.mma_shape[Int(1)] // config.output_tile_shape[Int(1)])]

SrcDescLayout​

comptime SrcDescLayout = Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].OutDescLayout

SrcTileLayout​

comptime SrcTileLayout = Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].OutTileLayout

SrcTileLoaderType​

comptime SrcTileLoaderType = TileLoaderTMA[_, _, _, _, _, cta_group=Int(1)]

SrcTmaOp​

comptime SrcTmaOp = TMATensorTile[out_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()]

stage_stride_cols​

comptime stage_stride_cols = (Int(512) // config.num_accum_pipeline_stages)

TilePayload​

comptime TilePayload = StandardTilePayload[act_type, filter_type, IndexList(config.block_tile_shape[Int(0)], config.block_tile_shape[Int(2)], __list_literal__=NoneType(None)), IndexList(config.block_tile_shape[Int(1)], config.block_tile_shape[Int(2)], __list_literal__=NoneType(None)), config.num_pipeline_stages]

TileWriterType​

comptime TileWriterType = TileWriter[act_type, Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].accum_type, config.block_tile_shape, config.mma_shape, Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].opc, config.c_swizzle, False, config.output_tile_shape[Int(0)], config.output_tile_shape[Int(1)], config.num_output_stages, Int(4), elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue]

TMA_LOAD_THREADS​

comptime TMA_LOAD_THREADS = WARP_SIZE

Tmem​

comptime Tmem = TmemAllocation[OutputPipelineConfig(config.num_accum_pipeline_stages, (Int(512) // config.num_accum_pipeline_stages), config.cta_group).cta_group]

TmemDealloc​

comptime TmemDealloc = TmemDeallocBarrier[OutputPipelineConfig(config.num_accum_pipeline_stages, (Int(512) // config.num_accum_pipeline_stages), config.cta_group).cta_group]

Methods​

mma​

static def mma[tiles_origin: MutOrigin, //](tmem_stage: TmemStage[Self.opc], tiles: ConsumerTiles[tiles_origin, StandardTilePayload[act_type, filter_type, IndexList(config.block_tile_shape[Int(0)], config.block_tile_shape[Int(2)], __list_literal__=NoneType(None)), IndexList(config.block_tile_shape[Int(1)], config.block_tile_shape[Int(2)], __list_literal__=NoneType(None)), config.num_pipeline_stages], (config // config), config.k_group_size], mma_op: MmaOpSM100_SS[accum_type=mma_op.accum_type, cta_group=mma_op.cta_group, cluster_shape=mma_op.cluster_shape, a_swizzle=mma_op.a_swizzle, b_swizzle=mma_op.b_swizzle, transpose_b=mma_op.transpose_b], elect_one_warp: Bool, iter_idx: UInt32, k_start: UInt32)

Execute MMA operations for one pipeline stage.

init_barriers​

static def init_barriers(ctx: KernelContext[config.num_clc_pipeline_stages, config.cta_group, config.cluster_shape[Int(0)], config.cluster_shape[Int(1)]], act_tma_op: TMATensorTileIm2col[act_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()], filter_tma_op: TMATensorTile[filter_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()], out_tma_op: TMATensorTile[out_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()], input_barriers: SMemArray[SharedMemBarrier, ((config // config) * Int(2))], accum_barriers: SMemArray[SharedMemBarrier, (config * Int(2))], clc_throttle: SMemArray[SharedMemBarrier, (config * Int(2))], clc_full: SMemArray[SharedMemBarrier, config.num_clc_pipeline_stages], clc_empty: SMemArray[SharedMemBarrier, config.num_clc_pipeline_stages], tmem_dealloc: SMemArray[SharedMemBarrier, Int(1)], epi_load_barriers: SMemArray[SharedMemBarrier, ((config.mma_shape[Int(1)] // config.output_tile_shape[Int(1)]) * Int(2))], load_order_barrier: SMemArray[SharedMemBarrier, Int(1)])

Initialize barriers and prefetch TMA descriptors.

load_input_tiles​

static def load_input_tiles[act_tma_origin: ImmutOrigin, filter_tma_origin: ImmutOrigin, tiles_origin: MutOrigin, //](act_loader: TileLoaderTMAIm2col[act_tma_origin, act_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]](), cta_group=config.cta_group], filter_loader: TileLoaderTMA[filter_tma_origin, filter_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]](), cta_group=config.cta_group], tiles: ProducerTiles[tiles_origin, StandardTilePayload[act_type, filter_type, IndexList(config.block_tile_shape[Int(0)], config.block_tile_shape[Int(2)], __list_literal__=NoneType(None)), IndexList(config.block_tile_shape[Int(1)], config.block_tile_shape[Int(2)], __list_literal__=NoneType(None)), config.num_pipeline_stages], (config // config), config.k_group_size], iter_idx: UInt32, work_m_coord: Int, work_n_coord: Int, peer_cta_coord: Tuple[Int, Int, Int], elect_one_cta: Bool)

Load activation (via im2col TMA) and filter tiles.

The im2col TMA descriptor handles coordinate transformation internally. Coordinates are in GEMM space:

  • work_m_coord: M coordinate (batch * H_out * W_out)
  • work_n_coord: N coordinate (output channels)
  • iter_idx: K dimension tile index (C * R * S)

run​

static def run(act_tma_op: TMATensorTileIm2col[act_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()], filter_tma_op: TMATensorTile[filter_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()], out_tma_op: TMATensorTile[out_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()], cluster_dim: StaticTuple[Int32, Int(3)], mnk: StaticTuple[UInt32, Int(3)])

Kernel entry point for Conv2D fprop (no residual).

Args:

run_with_residual​

static def run_with_residual(act_tma_op: TMATensorTileIm2col[act_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()], filter_tma_op: TMATensorTile[filter_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()], out_tma_op: TMATensorTile[out_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()], src_tma_op: TMATensorTile[out_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()], cluster_dim: StaticTuple[Int32, Int(3)], mnk: StaticTuple[UInt32, Int(3)], beta: Float32)

Kernel entry point for Conv2D fprop with residual (D = Conv + beta*C).

Args: