Mojo struct

Conv2dFpropKernel

struct Conv2dFpropKernel[act_type: DType, filter_type: DType, out_type: DType, act_layout: Layout, filter_layout: Layout, out_layout: Layout, act_desc_layout: Layout, filter_desc_layout: Layout, out_desc_layout: Layout, config: Conv2dConfig[act_type, filter_type, out_type], cluster_shape: StaticTuple[Int32, 3] = StaticTuple[Int32, 3](1), elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, register_based_epilogue: Bool = True, src_layout: Layout = out_layout, src_desc_layout: Layout = out_desc_layout]

SM100 Conv2D forward propagation kernel.

This kernel implements Conv2D forward propagation (fprop) as an implicit GEMM with warp specialization. It reuses the matmul kernel architecture, with convolution-specific address calculation.

The kernel is structured as specialized warps:

  • Scheduler warp: CLC-based (cluster launch control) tile scheduling
  • Load warp: TMA loads with im2col transformation
  • MMA warp: Tensor core operations
  • Epilogue warps: Write output from tensor memory (TMEM) to global memory (GMEM)

Parameters

  • act_type (DType): Activation data type.
  • filter_type (DType): Filter data type.
  • out_type (DType): Output data type.
  • act_layout (Layout): Global memory activation layout.
  • filter_layout (Layout): Global memory filter layout.
  • out_layout (Layout): Global memory output layout.
  • act_desc_layout (Layout): TMA descriptor layout for activation.
  • filter_desc_layout (Layout): TMA descriptor layout for filter.
  • out_desc_layout (Layout): TMA descriptor layout for output.
  • config (Conv2dConfig): Kernel configuration.
  • cluster_shape (StaticTuple): CUDA cluster dimensions.
  • elementwise_compute_lambda_fn (Optional): Optional epilogue lambda for fusion (bias add, activation functions, residual connections).
  • register_based_epilogue (Bool): Whether to apply the lambda in registers.
  • src_layout (Layout): Global memory layout for source C (residual input).
  • src_desc_layout (Layout): TMA descriptor layout for source C.

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

__del__is_trivial

comptime __del__is_trivial = True

accum_layout

comptime accum_layout = Layout.row_major(Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].MMA_M, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].MMA_N)

accum_pipeline_consumer_arv_count

comptime accum_pipeline_consumer_arv_count = (Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].cta_group * Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].EPILOGUE_THREADS)

accum_pipeline_producer_arv_count

comptime accum_pipeline_producer_arv_count = 1

accum_type

comptime accum_type = Conv2dConfig.accum_type[act_type, filter_type, out_type]()

AccumTensor

comptime AccumTensor = TmemTensor[Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].accum_type, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].accum_layout, cta_group=Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].cta_group]

act_expected_bytes

comptime act_expected_bytes = (Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].SmemType.act_smem_layout.size() * size_of[act_type]())

act_tma_load_size

comptime act_tma_load_size = act_desc_layout.size()

act_tma_rows

comptime act_tma_rows = act_desc_layout.shape[0].value()

ActTileLoaderTypeIm2col

comptime ActTileLoaderTypeIm2col = TileLoaderTMAIm2col[?, ?, ?, ?, cta_group=Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].cta_group]

BK

comptime BK = config.block_tile_shape.__getitem__[3, DType.int64, Int](2)

BM

comptime BM = config.block_tile_shape.__getitem__[3, DType.int64, Int](0)

BN

comptime BN = config.block_tile_shape.__getitem__[3, DType.int64, Int](1)

c_smem_layout

comptime c_smem_layout = Layout.row_major(Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].SmemType.OutputM, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].SmemType.OutputN)

clc_consumer_arv_count

comptime clc_consumer_arv_count = (Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].SCHEDULER_THREADS + (Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].CLUSTER_SIZE * (((Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].TMA_LOAD_THREADS + Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].MMA_THREADS) + Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].EPILOGUE_LOAD_THREADS) + Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].EPILOGUE_THREADS)))
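The clc_consumer_arv_count expression above reduces to simple arithmetic once the per-role thread counts listed on this page are substituted. The following Python sketch is only an illustration of that arithmetic, assuming WARP_SIZE is 32; the function name is hypothetical and not part of the kernel API.

```python
def clc_consumer_arv_count(cluster_m: int, cluster_n: int, warp_size: int = 32) -> int:
    """Mirror the comptime expression: scheduler threads plus, for every CTA
    in the cluster, the load, MMA, epilogue-load and epilogue threads."""
    scheduler_threads = warp_size
    tma_load_threads = warp_size
    mma_threads = warp_size
    epilogue_load_threads = warp_size
    epilogue_threads = 4 * warp_size

    cluster_size = cluster_m * cluster_n  # CLUSTER_SIZE = CLUSTER_M * CLUSTER_N
    per_cta = tma_load_threads + mma_threads + epilogue_load_threads + epilogue_threads
    return scheduler_threads + cluster_size * per_cta


print(clc_consumer_arv_count(1, 1))  # 256
print(clc_consumer_arv_count(2, 1))  # 480
```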

clc_producer_arv_count

comptime clc_producer_arv_count = 1

clc_throttle_consumer_arv_count

comptime clc_throttle_consumer_arv_count = Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].SCHEDULER_THREADS

clc_throttle_producer_arv_count

comptime clc_throttle_producer_arv_count = Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].TMA_LOAD_THREADS

CLUSTER_M

comptime CLUSTER_M = config.cluster_shape.__getitem__[3, DType.int64, Int](0)

CLUSTER_N

comptime CLUSTER_N = config.cluster_shape.__getitem__[3, DType.int64, Int](1)

CLUSTER_SIZE

comptime CLUSTER_SIZE = (Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].CLUSTER_M * Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].CLUSTER_N)

Context

comptime Context = KernelContext[Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].num_clc_pipeline_stages, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].cta_group, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].CLUSTER_M, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].CLUSTER_N]

cta_group

comptime cta_group = config.cta_group

epi_load_consumer_arv_count

comptime epi_load_consumer_arv_count = SIMD[DType.int32, 1](Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].EPILOGUE_THREADS)

epi_load_producer_arv_count

comptime epi_load_producer_arv_count = 1

EpiLoadPipelineType

comptime EpiLoadPipelineType = EpiLoadPipeline[Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].num_epi_load_stages]

EPILOGUE_LOAD_THREADS

comptime EPILOGUE_LOAD_THREADS = WARP_SIZE

EPILOGUE_THREADS

comptime EPILOGUE_THREADS = (4 * WARP_SIZE)

EpilogueCtx

comptime EpilogueCtx = EpilogueWarpContext[config.num_accum_pipeline_stages, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].stage_stride_cols, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].cta_group, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].MMA_THREADS, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].EPILOGUE_THREADS]

filter_expected_bytes

comptime filter_expected_bytes = (Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].SmemType.filter_smem_layout.size() * size_of[filter_type]())

filter_tma_load_size

comptime filter_tma_load_size = filter_desc_layout.size()

filter_tma_rows

comptime filter_tma_rows = filter_desc_layout.shape[0].value()

FilterTileLoaderType

comptime FilterTileLoaderType = TileLoaderTMA[?, ?, ?, ?, cta_group=Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].cta_group]

input_expected_bytes

comptime input_expected_bytes = ((Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].cta_group * (Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].act_expected_bytes + Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].filter_expected_bytes)) * config)
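As a rough illustration of how the expected-byte counts combine, the Python sketch below computes the inner sum cta_group * (act_expected_bytes + filter_expected_bytes). It assumes the activation tile in shared memory is BM x BK and the filter tile is BN x BK, which this page does not state explicitly; the helper names are hypothetical. The trailing factor of the expression above is cut off on this page, so the sketch does not model it.

```python
def tile_bytes(rows: int, cols: int, elem_bytes: int) -> int:
    """Bytes occupied by a dense rows x cols tile (layout size * element size)."""
    return rows * cols * elem_bytes


def act_plus_filter_bytes(bm: int, bn: int, bk: int,
                          act_elem_bytes: int, filter_elem_bytes: int,
                          cta_group: int) -> int:
    # Assumption: activation tile is BM x BK, filter tile is BN x BK.
    act_expected_bytes = tile_bytes(bm, bk, act_elem_bytes)
    filter_expected_bytes = tile_bytes(bn, bk, filter_elem_bytes)
    return cta_group * (act_expected_bytes + filter_expected_bytes)


# Example: 128 x 128 x 64 block tile, 2-byte elements (e.g. bfloat16), cta_group = 2.
print(act_plus_filter_bytes(128, 128, 64, 2, 2, cta_group=2))  # 65536
```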

InputTilePipelineType

comptime InputTilePipelineType = InputTilePipeline[Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].TilePayload, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].SmemType.num_group_pipeline_stages, config.k_group_size]

MMA_K

comptime MMA_K = config.mma_shape.__getitem__[3, DType.int64, Int](2)

MMA_M

comptime MMA_M = config.mma_shape.__getitem__[3, DType.int64, Int](0)

MMA_N

comptime MMA_N = config.mma_shape.__getitem__[3, DType.int64, Int](1)

MMA_THREADS

comptime MMA_THREADS = WARP_SIZE

MmaCtx

comptime MmaCtx = MmaWarpContext[config.num_accum_pipeline_stages, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].stage_stride_cols, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].cta_group, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].MMA_THREADS, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].EPILOGUE_THREADS]

MmaEpilogueSync

comptime MmaEpilogueSync = WarpGroupBarrier[(Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].MMA_THREADS + Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].EPILOGUE_THREADS), 1]

MmaOp

comptime MmaOp = MmaOpSM100_SS[out_type, act_type, filter_type, config.block_tile_shape, config.mma_shape, accum_type=Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].accum_type, cta_group=Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=True]

num_accum_pipeline_stages

comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages

num_clc_pipeline_stages

comptime num_clc_pipeline_stages = config.num_clc_pipeline_stages

num_epi_load_stages

comptime num_epi_load_stages = Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].SmemType.num_epi_load_stages

num_group_pipeline_stages

comptime num_group_pipeline_stages = (Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].num_pipeline_stages // config)

num_output_stages

comptime num_output_stages = config.num_output_stages

num_output_warps

comptime num_output_warps = 4

num_pipeline_stages

comptime num_pipeline_stages = config.num_pipeline_stages

NUM_THREADS

comptime NUM_THREADS = ((((Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].SCHEDULER_THREADS + Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].TMA_LOAD_THREADS) + Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].MMA_THREADS) + Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].EPILOGUE_LOAD_THREADS) + Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].EPILOGUE_THREADS)
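Substituting the per-role constants from this page into NUM_THREADS gives the block size. The Python sketch below assumes WARP_SIZE is 32 and simply repeats the sum; it is illustrative arithmetic, not kernel code.

```python
WARP_SIZE = 32  # assumption: 32 threads per warp

# Per-role thread counts, copied from the comptime members on this page.
SCHEDULER_THREADS = WARP_SIZE
TMA_LOAD_THREADS = WARP_SIZE
MMA_THREADS = WARP_SIZE
EPILOGUE_LOAD_THREADS = WARP_SIZE
EPILOGUE_THREADS = 4 * WARP_SIZE

NUM_THREADS = (
    SCHEDULER_THREADS
    + TMA_LOAD_THREADS
    + MMA_THREADS
    + EPILOGUE_LOAD_THREADS
    + EPILOGUE_THREADS
)

print(NUM_THREADS)               # 256
print(NUM_THREADS // WARP_SIZE)  # 8 warps per block
```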

NUM_TMEM_COLS

comptime NUM_TMEM_COLS = 512

OutputM

comptime OutputM = config.output_tile_shape.__getitem__[2, DType.int64, Int](0)

OutputN

comptime OutputN = config.output_tile_shape.__getitem__[2, DType.int64, Int](1)

OutputPipeline

comptime OutputPipeline = OutputTilePipeline[config.num_accum_pipeline_stages, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].stage_stride_cols, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].cta_group]

Scheduler

comptime Scheduler = TileScheduler[Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].num_clc_pipeline_stages, Index[dtype=DType.uint32](config.cluster_shape.__getitem__[3, DType.int64, Int](0), config.cluster_shape.__getitem__[3, DType.int64, Int](1), config.cluster_shape.__getitem__[3, DType.int64, Int](2)), block_swizzle_size=config.block_swizzle_size]

SCHEDULER_THREADS

comptime SCHEDULER_THREADS = WARP_SIZE

SmemType

comptime SmemType = Conv2dSmem[act_type, filter_type, out_type, config=config]

src_expected_bytes

comptime src_expected_bytes = ((Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].OutputM * Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].OutputN) * size_of[out_type]())

SrcCTileArray

comptime SrcCTileArray = SMemTileArray[out_type, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].c_smem_layout, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].SmemType.num_output_stages, 128]

SrcTileLoaderType

comptime SrcTileLoaderType = TileLoaderTMA[?, ?, ?, ?, cta_group=1]

stage_stride_cols

comptime stage_stride_cols = (512 // Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].num_accum_pipeline_stages)
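stage_stride_cols divides the 512 TMEM columns (NUM_TMEM_COLS) evenly among the accumulator pipeline stages. The Python sketch below illustrates the division; the per-stage starting columns are an assumption (stage i starting at i * stride) and the function names are hypothetical.

```python
NUM_TMEM_COLS = 512  # value of NUM_TMEM_COLS on this page


def stage_stride_cols(num_accum_pipeline_stages: int) -> int:
    """TMEM columns available to each accumulator stage."""
    return NUM_TMEM_COLS // num_accum_pipeline_stages


def stage_column_offsets(num_accum_pipeline_stages: int) -> list[int]:
    # Assumption: stage i starts at column i * stride.
    stride = stage_stride_cols(num_accum_pipeline_stages)
    return [i * stride for i in range(num_accum_pipeline_stages)]


print(stage_stride_cols(2))     # 256
print(stage_column_offsets(4))  # [0, 128, 256, 384]
```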

TilePayload

comptime TilePayload = StandardTilePayload[act_type, filter_type, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].BM, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].BK, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].BN, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].BK, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].SmemType.num_pipeline_stages]

TileWriterType

comptime TileWriterType = TileWriter[act_type, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].accum_type, config.block_tile_shape, config.mma_shape, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].cta_group, config.num_accum_pipeline_stages, config.c_swizzle, False, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].SmemType.OutputM, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].SmemType.OutputN, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].SmemType.num_output_stages, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].stage_stride_cols, 4, elementwise_compute_lambda_fn, register_based_epilogue]

TMA_LOAD_THREADS

comptime TMA_LOAD_THREADS = WARP_SIZE

Tmem

comptime Tmem = TmemAllocation[Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].cta_group]

TmemDealloc

comptime TmemDealloc = TmemDeallocBarrier[Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].cta_group]

Methods

mma

static mma[tiles_origin: MutOrigin, //](tmem_stage: TmemStage[config.num_accum_pipeline_stages, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].stage_stride_cols, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].cta_group], tiles: InputConsumerStage[tiles_origin, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].TilePayload, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].SmemType.num_group_pipeline_stages, config.k_group_size], mma_op: MmaOpSM100_SS[c_type, a_type, b_type, block_tile_shape, mma_shape, accum_type=accum_type, cta_group=cta_group, cluster_shape=cluster_shape, a_swizzle=a_swizzle, b_swizzle=b_swizzle, transpose_b=transpose_b], elect_one_warp: Bool, iter_idx: UInt32, k_start: UInt32)

Execute MMA operations for one pipeline stage.

init_barriers

static init_barriers(ctx: KernelContext[Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].num_clc_pipeline_stages, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].cta_group, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].CLUSTER_M, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].CLUSTER_N], act_tma_op: TMATensorTileIm2col[act_type, act_layout, act_desc_layout], filter_tma_op: TMATensorTile[filter_type, filter_layout, filter_desc_layout], out_tma_op: TMATensorTile[out_type, out_layout, out_desc_layout], input_barriers: SMemArray[SharedMemBarrier, (Conv2dSmem[act_type, filter_type, out_type, config=config].num_group_pipeline_stages * 2)], accum_barriers: SMemArray[SharedMemBarrier, (Conv2dSmem[act_type, filter_type, out_type, config=config].num_accum_pipeline_stages * 2)], clc_throttle: SMemArray[SharedMemBarrier, (Conv2dSmem[act_type, filter_type, out_type, config=config].num_clc_pipeline_stages * 2)], clc_full: SMemArray[SharedMemBarrier, Conv2dSmem[act_type, filter_type, out_type, config=config].num_clc_pipeline_stages], clc_empty: SMemArray[SharedMemBarrier, Conv2dSmem[act_type, filter_type, out_type, config=config].num_clc_pipeline_stages], tmem_dealloc: SMemArray[SharedMemBarrier, 1], epi_load_barriers: SMemArray[SharedMemBarrier, (Conv2dSmem[act_type, filter_type, out_type, config=config].num_epi_load_stages * 2)], load_order_barrier: SMemArray[SharedMemBarrier, 1])

Initialize barriers and prefetch TMA descriptors.

load_input_tiles

static load_input_tiles[act_tma_origin: ImmutOrigin, filter_tma_origin: ImmutOrigin, tiles_origin: MutOrigin, //](act_loader: TileLoaderTMAIm2col[act_tma_origin, act_type, act_layout, act_desc_layout, cta_group=Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].cta_group], filter_loader: TileLoaderTMA[filter_tma_origin, filter_type, filter_layout, filter_desc_layout, cta_group=Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].cta_group], tiles: InputProducerStage[tiles_origin, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].TilePayload, Conv2dFpropKernel[act_type, filter_type, out_type, act_layout, filter_layout, out_layout, act_desc_layout, filter_desc_layout, out_desc_layout, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, src_layout, src_desc_layout].SmemType.num_group_pipeline_stages, config.k_group_size], iter_idx: UInt32, work_m_coord: Scalar[DType.uint], work_n_coord: Scalar[DType.uint], peer_cta_coord: Tuple[UInt, UInt, UInt], elect_one_cta: Bool)

Load activation (via im2col TMA) and filter tiles.

The im2col TMA descriptor handles coordinate transformation internally. Coordinates are in GEMM space:

  • work_m_coord: Coordinate along the M dimension (M = batch * H_out * W_out)
  • work_n_coord: Coordinate along the N dimension (N = output channels)
  • iter_idx: Tile index along the K dimension (K = C * R * S)
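The GEMM-space dimensions listed above follow from the convolution shape. The Python sketch below shows one way to derive them and to map an M coordinate back to an output position. The output-size formula is the standard convolution one; the decomposition order (batch, then output row, then output column) is an assumption, and all function names are hypothetical. Inside the kernel, the im2col TMA descriptor performs the equivalent transformation.

```python
def conv_out_size(in_size: int, kernel: int, stride: int = 1, pad: int = 0) -> int:
    """Standard convolution output size along one spatial dimension."""
    return (in_size + 2 * pad - kernel) // stride + 1


def implicit_gemm_shape(batch, h, w, c, out_channels, r, s, stride=1, pad=0):
    """Return (M, N, K) of the implicit GEMM for a conv with an R x S filter."""
    h_out = conv_out_size(h, r, stride, pad)
    w_out = conv_out_size(w, s, stride, pad)
    m = batch * h_out * w_out   # one GEMM row per output pixel
    n = out_channels            # one GEMM column per output channel
    k = c * r * s               # reduction over input channels and filter taps
    return m, n, k


def decompose_m(m_coord: int, h_out: int, w_out: int):
    # Assumption: M is ordered batch-major, then output row, then output column.
    n_idx, rem = divmod(m_coord, h_out * w_out)
    h_idx, w_idx = divmod(rem, w_out)
    return n_idx, h_idx, w_idx


print(implicit_gemm_shape(2, 8, 8, 16, 32, 3, 3, stride=1, pad=1))  # (128, 32, 144)
print(decompose_m(70, 8, 8))                                        # (1, 0, 6)
```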

run

static run(act_tma_op: TMATensorTileIm2col[act_type, act_layout, act_desc_layout], filter_tma_op: TMATensorTile[filter_type, filter_layout, filter_desc_layout], out_tma_op: TMATensorTile[out_type, out_layout, out_desc_layout], cluster_dim: StaticTuple[Int32, 3], mnk: StaticTuple[UInt32, 3])

Kernel entry point for Conv2D fprop (no residual).

run_with_residual

static run_with_residual(act_tma_op: TMATensorTileIm2col[act_type, act_layout, act_desc_layout], filter_tma_op: TMATensorTile[filter_type, filter_layout, filter_desc_layout], out_tma_op: TMATensorTile[out_type, out_layout, out_desc_layout], src_tma_op: TMATensorTile[out_type, src_layout, src_desc_layout], cluster_dim: StaticTuple[Int32, 3], mnk: StaticTuple[UInt32, 3], beta: Float32)

Kernel entry point for Conv2D fprop with residual (D = Conv + beta*C).
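The residual formula D = Conv + beta*C can be illustrated elementwise on a small tile. The Python sketch below is a host-side reference only, with a hypothetical function name; it says nothing about how or where the kernel applies the optional elementwise lambda.

```python
def apply_residual(conv_tile, src_tile, beta):
    """Elementwise D = Conv + beta * C on two equally shaped 2-D tiles."""
    return [
        [acc + beta * c for acc, c in zip(conv_row, src_row)]
        for conv_row, src_row in zip(conv_tile, src_tile)
    ]


conv = [[1.0, 2.0], [3.0, 4.0]]     # convolution result tile
src = [[10.0, 20.0], [30.0, 40.0]]  # source C (residual input) tile
print(apply_residual(conv, src, beta=0.5))  # [[6.0, 12.0], [18.0, 24.0]]
```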
