For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).
Mojo struct
Conv2dFpropKernel
struct Conv2dFpropKernel[act_type: DType, filter_type: DType, out_type: DType, config: Conv2dConfig[act_type, filter_type, out_type], cluster_shape: StaticTuple[Int32, Int(3)] = StaticTuple(Int32(1)), elementwise_lambda_fn: Optional[def[dtype: DType, width: SIMDSize, *, alignment: Int = Int(1)](IndexList[Int(2)], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: Optional[def[dtype: DType, width: SIMDSize, *, alignment: Int = Int(1)](IndexList[Int(2)], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, register_based_epilogue: Bool = True]
SM100 Conv2D forward propagation kernel.
This kernel implements conv2d fprop using implicit GEMM with warp specialization. It reuses the matmul kernel architecture but with convolution-specific address calculation.
The kernel structure:
- Scheduler warp: CLC-based tile scheduling
- Load warp: TMA loads with im2col transformation
- MMA warp: Tensor core operations
- Epilogue warps: Output from TMEM to GMEM
Parameters
- act_type (DType): Activation data type.
- filter_type (DType): Filter data type.
- out_type (DType): Output data type.
- config (Conv2dConfig[act_type, filter_type, out_type]): Kernel configuration.
- cluster_shape (StaticTuple[Int32, Int(3)]): CUDA cluster dimensions.
- elementwise_lambda_fn (Optional[def[dtype: DType, width: SIMDSize, *, alignment: Int = Int(1)](IndexList[Int(2)], SIMD[dtype, width]) capturing -> None]): Optional void epilogue lambda applied after the output is written. Signature: def(IndexList[2], SIMD) -> None.
- elementwise_compute_lambda_fn (Optional[def[dtype: DType, width: SIMDSize, *, alignment: Int = Int(1)](IndexList[Int(2)], SIMD[dtype, width]) capturing -> SIMD[dtype, width]]): Optional epilogue lambda for fusion (bias add, activation functions, residual connections).
- register_based_epilogue (Bool): Whether to apply the epilogue lambda in registers.
Implemented traits
comptime members
accum_layout
comptime accum_layout = Layout.row_major(config.mma_shape[Int(0)], config.mma_shape[Int(1)])
accum_pipeline_consumer_arv_count
comptime accum_pipeline_consumer_arv_count = (config * Int(_resolve_warp_size() * 4))
accum_pipeline_producer_arv_count
comptime accum_pipeline_producer_arv_count = 1
accum_type
comptime accum_type = Conv2dConfig.accum_type()
AccumTensor
comptime AccumTensor = TmemTensor[Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].accum_type, Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].accum_layout, cta_group=config.cta_group]
act_expected_bytes
comptime act_expected_bytes = (Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.act_smem_elements * size_of[act_type]())
act_swizzle_elems
comptime act_swizzle_elems = (config.a_swizzle.bytes() // size_of[act_type]())
act_tile_dim0
comptime act_tile_dim0 = (config.block_tile_shape[Int(0)] // config.cluster_shape[Int(1)])
act_tma_load_size
comptime act_tma_load_size = ((config.block_tile_shape[Int(0)] // config.cluster_shape[Int(1)]) * (config.a_swizzle.bytes() // size_of[act_type]()))
act_tma_rows
comptime act_tma_rows = Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].act_tile_dim0
ActDescLayout
comptime ActDescLayout = Layout[*?, *?]
ActTileLayout
comptime ActTileLayout = Layout[*?, *?]
ActTileLoaderTypeIm2col
comptime ActTileLoaderTypeIm2col = TileLoaderTMAIm2col[_, _, _, _, _, cta_group=config.cta_group]
ActTmaOp
comptime ActTmaOp = TMATensorTileIm2col[act_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()]
BK
comptime BK = config.block_tile_shape[Int(2)]
BM
comptime BM = config.block_tile_shape[Int(0)]
BN
comptime BN = config.block_tile_shape[Int(1)]
clc_consumer_arv_count
comptime clc_consumer_arv_count = (_resolve_warp_size() + Int(config.cluster_shape[Int(1)] * config.cluster_shape[Int(0)] * _resolve_warp_size() * 7))
clc_producer_arv_count
comptime clc_producer_arv_count = 1
clc_throttle_consumer_arv_count
comptime clc_throttle_consumer_arv_count = Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].SCHEDULER_THREADS
clc_throttle_producer_arv_count
comptime clc_throttle_producer_arv_count = Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].TMA_LOAD_THREADS
CLUSTER_M
comptime CLUSTER_M = config.cluster_shape[Int(0)]
CLUSTER_N
comptime CLUSTER_N = config.cluster_shape[Int(1)]
CLUSTER_SIZE
comptime CLUSTER_SIZE = (config.cluster_shape[Int(0)] * config.cluster_shape[Int(1)])
Context
comptime Context = KernelContext[config.num_clc_pipeline_stages, config.cta_group, config.cluster_shape[Int(0)], config.cluster_shape[Int(1)]]
cta_group
comptime cta_group = config.cta_group
epi_load_consumer_arv_count
comptime epi_load_consumer_arv_count = SIMD(Int(_resolve_warp_size() * 4))
epi_load_producer_arv_count
comptime epi_load_producer_arv_count = Int32(1)
EpiLoadPipelineType
comptime EpiLoadPipelineType = EpiLoadPipeline[(config.mma_shape[Int(1)] // config.output_tile_shape[Int(1)])]
EPILOGUE_LOAD_THREADS
comptime EPILOGUE_LOAD_THREADS = WARP_SIZE
EPILOGUE_THREADS
comptime EPILOGUE_THREADS = (Int(4) * _resolve_warp_size())
EpilogueCtx
comptime EpilogueCtx = EpilogueWarpContext[Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].opc, _resolve_warp_size(), Int(_resolve_warp_size() * 4)]
filter_expected_bytes
comptime filter_expected_bytes = (Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.filter_smem_elements * size_of[filter_type]())
filter_swizzle_elems
comptime filter_swizzle_elems = (config.b_swizzle.bytes() // size_of[filter_type]())
filter_tile_dim0
comptime filter_tile_dim0 = (config.block_tile_shape[Int(1)] // (config.cluster_shape[Int(0)] // config))
filter_tma_load_size
comptime filter_tma_load_size = ((config.block_tile_shape[Int(1)] // (config.cluster_shape[Int(0)] // config)) * (config.b_swizzle.bytes() // size_of[filter_type]()))
filter_tma_rows
comptime filter_tma_rows = Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].filter_tile_dim0
FilterDescLayout
comptime FilterDescLayout = Layout[*?, *?]
FilterTileLayout
comptime FilterTileLayout = Layout[*?, *?]
FilterTileLoaderType
comptime FilterTileLoaderType = TileLoaderTMA[_, _, _, _, _, cta_group=config.cta_group]
FilterTmaOp
comptime FilterTmaOp = TMATensorTile[filter_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()]
input_expected_bytes
comptime input_expected_bytes = (Int((config.block_tile_shape[Int(2)] // (config.a_swizzle.bytes() // size_of[act_type]())) * (config.block_tile_shape[Int(0)] // Int(8)) * (config.a_swizzle.bytes() // size_of[act_type]()) * size_of[act_type]() * config.cta_group * 8 + (config.block_tile_shape[Int(2)] // (config.b_swizzle.bytes() // size_of[filter_type]())) * (config.block_tile_shape[Int(1)] // Int(8)) * (config.b_swizzle.bytes() // size_of[filter_type]()) * size_of[filter_type]() * config.cta_group * 8) * config)
InputTilePipelineType
comptime InputTilePipelineType = InputTilePipeline[StandardTilePayload[act_type, filter_type, IndexList(config.block_tile_shape[Int(0)], config.block_tile_shape[Int(2)], __list_literal__=NoneType(None)), IndexList(config.block_tile_shape[Int(1)], config.block_tile_shape[Int(2)], __list_literal__=NoneType(None)), config.num_pipeline_stages], (config // config), config.k_group_size]
MMA_K
comptime MMA_K = config.mma_shape[Int(2)]
MMA_M
comptime MMA_M = config.mma_shape[Int(0)]
MMA_N
comptime MMA_N = config.mma_shape[Int(1)]
MMA_THREADS
comptime MMA_THREADS = WARP_SIZE
MmaCtx
comptime MmaCtx = MmaWarpContext[Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].opc, _resolve_warp_size(), Int(_resolve_warp_size() * 4)]
MmaEpilogueSync
comptime MmaEpilogueSync = WarpGroupBarrier[(_resolve_warp_size() + Int(_resolve_warp_size() * 4)), Int(1)]
MmaOp
comptime MmaOp = MmaOpSM100_SS[out_type, act_type, filter_type, config.block_tile_shape, config.mma_shape, accum_type=Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].accum_type, cta_group=config.cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=True]
num_accum_pipeline_stages
comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages
num_clc_pipeline_stages
comptime num_clc_pipeline_stages = config.num_clc_pipeline_stages
num_epi_load_stages
comptime num_epi_load_stages = Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.num_epi_load_stages
num_group_pipeline_stages
comptime num_group_pipeline_stages = (config // config)
num_output_stages
comptime num_output_stages = config.num_output_stages
num_output_warps
comptime num_output_warps = 4
num_pipeline_stages
comptime num_pipeline_stages = config.num_pipeline_stages
NUM_THREADS
comptime NUM_THREADS = (Int(_resolve_warp_size() * 4) + Int(_resolve_warp_size() * 4))
NUM_TMEM_COLS
comptime NUM_TMEM_COLS = 512
opc
comptime opc = OutputPipelineConfig(Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].num_accum_pipeline_stages, Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].stage_stride_cols, Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].cta_group)
out_swizzle_elems
comptime out_swizzle_elems = (config.c_swizzle.bytes() // size_of[out_type]())
out_tile_dim0
comptime out_tile_dim0 = Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].OutputM if ((config.mma_shape[Int(0)] == Int(256)) or (config == Int(1))) else Int(64)
OutDescLayout
comptime OutDescLayout = Layout[*?, *?]
OutputM
comptime OutputM = config.output_tile_shape[Int(0)]
OutputN
comptime OutputN = config.output_tile_shape[Int(1)]
OutputPipeline
comptime OutputPipeline = OutputTilePipeline[Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].opc]
OutTileLayout
comptime OutTileLayout = Layout[*?, *?]
OutTmaOp
comptime OutTmaOp = TMATensorTile[out_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()]
Scheduler
comptime Scheduler = TileScheduler[config.num_clc_pipeline_stages, Index[Int, Int, Int, dtype=DType.uint32](config.cluster_shape[Int(0)], config.cluster_shape[Int(1)], config.cluster_shape[Int(2)]), block_swizzle_size=config.block_swizzle_size]
SCHEDULER_THREADS
comptime SCHEDULER_THREADS = WARP_SIZE
SmemType
comptime SmemType = Conv2dSmem[act_type, filter_type, out_type, config=config]
src_expected_bytes
comptime src_expected_bytes = (Int(config.output_tile_shape[Int(0)] * config.output_tile_shape[Int(1)]) * size_of[out_type]())
SrcCTileArray
comptime SrcCTileArray = SMemTileArray2DRowMajor[out_type, config.output_tile_shape[Int(0)], config.output_tile_shape[Int(1)], (config.mma_shape[Int(1)] // config.output_tile_shape[Int(1)])]
SrcDescLayout
comptime SrcDescLayout = Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].OutDescLayout
SrcTileLayout
comptime SrcTileLayout = Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].OutTileLayout
SrcTileLoaderType
comptime SrcTileLoaderType = TileLoaderTMA[_, _, _, _, _, cta_group=Int(1)]
SrcTmaOp
comptime SrcTmaOp = TMATensorTile[out_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()]
stage_stride_cols
comptime stage_stride_cols = (Int(512) // config)
TilePayload
comptime TilePayload = StandardTilePayload[act_type, filter_type, IndexList(config.block_tile_shape[Int(0)], config.block_tile_shape[Int(2)], __list_literal__=NoneType(None)), IndexList(config.block_tile_shape[Int(1)], config.block_tile_shape[Int(2)], __list_literal__=NoneType(None)), config.num_pipeline_stages]
TileWriterType
comptime TileWriterType = TileWriter[act_type, Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].accum_type, config.block_tile_shape, config.mma_shape, Conv2dFpropKernel[act_type, filter_type, out_type, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue].opc, config.c_swizzle, False, config.output_tile_shape[Int(0)], config.output_tile_shape[Int(1)], config.num_output_stages, Int(4), elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue]
TMA_LOAD_THREADS
comptime TMA_LOAD_THREADS = WARP_SIZE
Tmem
comptime Tmem = TmemAllocation[OutputPipelineConfig(config.num_accum_pipeline_stages, (Int(512) // config), config.cta_group).cta_group]
TmemDealloc
comptime TmemDealloc = TmemDeallocBarrier[OutputPipelineConfig(config.num_accum_pipeline_stages, (Int(512) // config), config.cta_group).cta_group]
Methods
mma
static def mma[tiles_origin: MutOrigin, //](tmem_stage: TmemStage[Self.opc], tiles: ConsumerTiles[tiles_origin, StandardTilePayload[act_type, filter_type, IndexList(config.block_tile_shape[Int(0)], config.block_tile_shape[Int(2)], __list_literal__=NoneType(None)), IndexList(config.block_tile_shape[Int(1)], config.block_tile_shape[Int(2)], __list_literal__=NoneType(None)), config.num_pipeline_stages], (config // config), config.k_group_size], mma_op: MmaOpSM100_SS[accum_type=mma_op.accum_type, cta_group=mma_op.cta_group, cluster_shape=mma_op.cluster_shape, a_swizzle=mma_op.a_swizzle, b_swizzle=mma_op.b_swizzle, transpose_b=mma_op.transpose_b], elect_one_warp: Bool, iter_idx: UInt32, k_start: UInt32)
Execute MMA operations for one pipeline stage.
init_barriers
static def init_barriers(ctx: KernelContext[config.num_clc_pipeline_stages, config.cta_group, config.cluster_shape[Int(0)], config.cluster_shape[Int(1)]], act_tma_op: TMATensorTileIm2col[act_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()], filter_tma_op: TMATensorTile[filter_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()], out_tma_op: TMATensorTile[out_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()], input_barriers: SMemArray[SharedMemBarrier, ((config // config) * Int(2))], accum_barriers: SMemArray[SharedMemBarrier, (config * Int(2))], clc_throttle: SMemArray[SharedMemBarrier, (config * Int(2))], clc_full: SMemArray[SharedMemBarrier, config.num_clc_pipeline_stages], clc_empty: SMemArray[SharedMemBarrier, config.num_clc_pipeline_stages], tmem_dealloc: SMemArray[SharedMemBarrier, Int(1)], epi_load_barriers: SMemArray[SharedMemBarrier, ((config.mma_shape[Int(1)] // config.output_tile_shape[Int(1)]) * Int(2))], load_order_barrier: SMemArray[SharedMemBarrier, Int(1)])
Initialize barriers and prefetch TMA descriptors.
load_input_tiles
static def load_input_tiles[act_tma_origin: ImmutOrigin, filter_tma_origin: ImmutOrigin, tiles_origin: MutOrigin, //](act_loader: TileLoaderTMAIm2col[act_tma_origin, act_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]](), cta_group=config.cta_group], filter_loader: TileLoaderTMA[filter_tma_origin, filter_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]](), cta_group=config.cta_group], tiles: ProducerTiles[tiles_origin, StandardTilePayload[act_type, filter_type, IndexList(config.block_tile_shape[Int(0)], config.block_tile_shape[Int(2)], __list_literal__=NoneType(None)), IndexList(config.block_tile_shape[Int(1)], config.block_tile_shape[Int(2)], __list_literal__=NoneType(None)), config.num_pipeline_stages], (config // config), config.k_group_size], iter_idx: UInt32, work_m_coord: Int, work_n_coord: Int, peer_cta_coord: Tuple[Int, Int, Int], elect_one_cta: Bool)
Load activation (via im2col TMA) and filter tiles.
The im2col TMA descriptor handles coordinate transformation internally. Coordinates are in GEMM space:
- work_m_coord: M coordinate (batch * H_out * W_out)
- work_n_coord: N coordinate (output channels)
- iter_idx: K dimension tile index (C * R * S)
run
static def run(act_tma_op: TMATensorTileIm2col[act_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()], filter_tma_op: TMATensorTile[filter_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()], out_tma_op: TMATensorTile[out_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()], cluster_dim: StaticTuple[Int32, Int(3)], mnk: StaticTuple[UInt32, Int(3)])
Kernel entry point for Conv2D fprop (no residual).
Args:
- act_tma_op (TMATensorTileIm2col[act_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()]): Im2col TMA descriptor for the activation.
- filter_tma_op (TMATensorTile[filter_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()]): TMA descriptor for the filter.
- out_tma_op (TMATensorTile[out_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()]): TMA descriptor for the output.
- cluster_dim (StaticTuple[Int32, Int(3)]): Cluster dimensions.
- mnk (StaticTuple[UInt32, Int(3)]): GEMM dimensions (M, N, K).
run_with_residual
static def run_with_residual(act_tma_op: TMATensorTileIm2col[act_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()], filter_tma_op: TMATensorTile[filter_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()], out_tma_op: TMATensorTile[out_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()], src_tma_op: TMATensorTile[out_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()], cluster_dim: StaticTuple[Int32, Int(3)], mnk: StaticTuple[UInt32, Int(3)], beta: Float32)
Kernel entry point for Conv2D fprop with residual (D = Conv + beta*C).
Args:
- act_tma_op (TMATensorTileIm2col[act_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()]): Im2col TMA descriptor for the activation.
- filter_tma_op (TMATensorTile[filter_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()]): TMA descriptor for the filter.
- out_tma_op (TMATensorTile[out_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()]): TMA descriptor for the output D.
- src_tma_op (TMATensorTile[out_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()]): TMA descriptor for the source C (residual input).
- cluster_dim (StaticTuple[Int32, Int(3)]): Cluster dimensions.
- mnk (StaticTuple[UInt32, Int(3)]): GEMM dimensions (M, N, K).
- beta (Float32): Residual scale factor.
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!