
Mojo struct

BlackwellMatmulSM100Kernel

struct BlackwellMatmulSM100Kernel[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], cluster_shape: StaticTuple[Int32, Int(3)] = StaticTuple(Int32(1)), elementwise_lambda_fn: Optional[def[dtype: DType, width: SIMDSize, *, alignment: Int = Int(1)](IndexList[Int(2)], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: Optional[def[dtype: DType, width: SIMDSize, *, alignment: Int = Int(1)](IndexList[Int(2)], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, pdl_level: PDLLevel = PDLLevel(), max_profiled_tiles_per_SM: UInt32 = UInt32(0)]

Blackwell SM100 GEMM kernel with warp specialization.

This struct unifies all parameters and derived types for the SM100 matmul kernel, providing:

  • Compile-time parameter validation
  • Centralized derived type computation
  • Factory methods for kernel components
  • Multiple kernel entry points (standard, split-k)

The SM100 kernel uses:

  • Tensor Memory (TMEM) for MMA accumulators
  • Cluster Launch Control (CLC) for dynamic tile scheduling
  • Warp specialization: Scheduler, TMA Load, MMA, Epilogue warps
  • Software pipelining for overlapping compute and memory operations

Implemented traits

AnyType, ImplicitlyDeletable

comptime members

a_expected_bytes​

comptime a_expected_bytes = (Int((mul (config.block_tile_shape[Int(2)] // (config.a_swizzle.bytes() // size_of[a_type]())), (config.block_tile_shape[Int(0)] // Int(8)), (config.a_swizzle.bytes() // size_of[a_type]()), 8)) * size_of[a_type]())

a_smem_layout​

comptime a_smem_layout = Layout(Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt())), Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt()))).to_layout()

a_swizzle_elems​

comptime a_swizzle_elems = (config.a_swizzle.bytes() // size_of[a_type]())
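As a worked check of the `a_expected_bytes` and `a_swizzle_elems` expressions, the Python sketch below evaluates the same product for one hypothetical configuration. The tile shape (BM = 128, BK = 64), the 128-byte swizzle, and the 2-byte element size are assumptions chosen for illustration; they are not taken from any real `MatmulConfig`. The helper names are likewise illustrative.

```python
def swizzle_elems(swizzle_bytes: int, elem_bytes: int) -> int:
    """Elements per swizzle row: swizzle.bytes() // size_of[dtype]()."""
    return swizzle_bytes // elem_bytes


def expected_bytes(tile_rows: int, bk: int, swizzle_bytes: int, elem_bytes: int) -> int:
    """Mirror of the documented product:
    (BK // swizzle_elems) * (rows // 8) * swizzle_elems * 8 * size_of[dtype]().
    """
    se = swizzle_elems(swizzle_bytes, elem_bytes)
    elems = (bk // se) * (tile_rows // 8) * se * 8
    return elems * elem_bytes


# Hypothetical configuration (assumed values, for illustration only).
BM, BK = 128, 64
a_bytes = expected_bytes(BM, BK, swizzle_bytes=128, elem_bytes=2)
print(a_bytes)  # 16384
```

When BK is a multiple of the swizzle width and the row count is a multiple of 8, the product reduces to rows x BK x element size, which is the byte size of one full tile (here 128 x 64 x 2).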

a_tile_dim0​

comptime a_tile_dim0 = (config.block_tile_shape[Int(0)] // config.cluster_shape[Int(1)])

a_tma_load_size​

comptime a_tma_load_size = ((config.block_tile_shape[Int(0)] // config.cluster_shape[Int(1)]) * (config.a_swizzle.bytes() // size_of[a_type]()))
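The relationship between `a_tile_dim0` and `a_tma_load_size` can be checked numerically. The sketch below is a plain Python restatement of the two expressions; the function names and the sample values (BM = 128, a 128-byte swizzle, 2-byte elements) are assumptions for illustration only.

```python
def a_tile_dim0(bm: int, cluster_n: int) -> int:
    """Rows of A covered by one TMA load: BM // CLUSTER_N."""
    return bm // cluster_n


def a_tma_load_size(bm: int, cluster_n: int, swizzle_bytes: int, elem_bytes: int) -> int:
    """Elements per TMA load: a_tile_dim0 * a_swizzle_elems."""
    swizzle_elems = swizzle_bytes // elem_bytes
    return a_tile_dim0(bm, cluster_n) * swizzle_elems


# Hypothetical values: the same tile with a cluster of 1 or 2 CTAs along N.
for cluster_n in (1, 2):
    rows = a_tile_dim0(128, cluster_n)
    size = a_tma_load_size(128, cluster_n, swizzle_bytes=128, elem_bytes=2)
    print(cluster_n, rows, size)
```

Doubling `CLUSTER_N` halves the rows per load. One plausible reading is that each CTA in the cluster fetches its own slice of the shared A tile, but the source does not state this explicitly.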

a_tma_rows​

comptime a_tma_rows = BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].a_tile_dim0

accum_layout​

comptime accum_layout = Layout.row_major(config.mma_shape[Int(0)], config.mma_shape[Int(1)])

accum_pipeline_consumer_arv_count​

comptime accum_pipeline_consumer_arv_count = compute_accum_barrier_counts[Int((mul _resolve_warp_size(), 4)), config.cta_group]()[Int(1)]

accum_pipeline_producer_arv_count​

comptime accum_pipeline_producer_arv_count = compute_accum_barrier_counts[Int((mul _resolve_warp_size(), 4)), config.cta_group]()[Int(0)]

accum_type​

comptime accum_type = MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type

AccumTensor​

comptime AccumTensor = TmemTensor[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].accum_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].accum_layout, cta_group=config.cta_group]

ADescLayout​

comptime ADescLayout = Layout[*?, *?]

ADescLayout_splitk​

comptime ADescLayout_splitk = Layout[*?, *?]

ATileLayout​

comptime ATileLayout = Layout[*?, *?]

ATileLayout_splitk​

comptime ATileLayout_splitk = Layout[*?, *?]

ATmaOp​

comptime ATmaOp = TMATensorTile[a_type, Int(3), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(3), Layout[*?, *?]]()]

ATmaOp_splitk​

comptime ATmaOp_splitk = TMATensorTile[a_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()]

b_expected_bytes​

comptime b_expected_bytes = (Int((mul (config.block_tile_shape[Int(2)] // (config.b_swizzle.bytes() // size_of[b_type]())), (config.block_tile_shape[Int(1)] // Int(8)), (config.b_swizzle.bytes() // size_of[b_type]()), 8)) * size_of[b_type]())

b_smem_layout​

comptime b_smem_layout = Layout(Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt())), Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt()))).to_layout() if transpose_b else Layout(Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt())), Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt()))).transpose().to_layout()

b_swizzle_elems​

comptime b_swizzle_elems = (config.b_swizzle.bytes() // size_of[b_type]())

b_tile_dim0​

comptime b_tile_dim0 = (config.block_tile_shape[Int(1)] // (config.cluster_shape[Int(0)] // config))

b_tma_load_size​

comptime b_tma_load_size = ((config.block_tile_shape[Int(1)] // (config.cluster_shape[Int(0)] // config)) * (config.b_swizzle.bytes() // size_of[b_type]()))

b_tma_rows​

comptime b_tma_rows = BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].b_tile_dim0

BDescLayout​

comptime BDescLayout = Layout[*?, *?]

BDescLayout_splitk​

comptime BDescLayout_splitk = Layout[*?, *?]

Bias1DTile​

comptime Bias1DTile = TileTensor[c_type, Layout[*?, *?], ImmutAnyOrigin]

Bias1DTileLayout​

comptime Bias1DTileLayout = row_major[Int(1), config.mma_shape[Int(1)]]()

BK​

comptime BK = config.block_tile_shape[Int(2)]

BM​

comptime BM = config.block_tile_shape[Int(0)]

BN​

comptime BN = config.block_tile_shape[Int(1)]

BTileLayout​

comptime BTileLayout = Layout[*?, *?]

BTileLayout_splitk​

comptime BTileLayout_splitk = Layout[*?, *?]

BTmaOp​

comptime BTmaOp = TMATensorTile[b_type, Int(3), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(3), Layout[*?, *?]]()]

BTmaOp_splitk​

comptime BTmaOp_splitk = TMATensorTile[b_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()]

c_swizzle_elems​

comptime c_swizzle_elems = (config.c_swizzle.bytes() // size_of[c_type]())

c_tile_dim0​

comptime c_tile_dim0 = BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].OutputM if (config.mma_shape[Int(0)] == Int(256)) if (config.mma_shape[Int(0)] == Int(256)) else (config == Int(1)) or config.AB_swapped else Int(64)

c_tile_dim1​

comptime c_tile_dim1 = BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].c_swizzle_elems if config.AB_swapped else BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].OutputN

CDescLayout​

comptime CDescLayout = Layout[*?, *?]

CDescLayout_splitk​

comptime CDescLayout_splitk = Layout[*?, *?]

clc_consumer_arv_count​

comptime clc_consumer_arv_count = compute_clc_barrier_counts[_resolve_warp_size(), _resolve_warp_size(), _resolve_warp_size(), (Int((mul _resolve_warp_size(), 4)) + _resolve_warp_size() if config.use_tma_epilogue_load else Int(0)), Int((mul config.cluster_shape[Int(0)], config.cluster_shape[Int(1)])), config.cta_group]()[Int(1)]

clc_producer_arv_count​

comptime clc_producer_arv_count = compute_clc_barrier_counts[_resolve_warp_size(), _resolve_warp_size(), _resolve_warp_size(), (Int((mul _resolve_warp_size(), 4)) + _resolve_warp_size() if config.use_tma_epilogue_load else Int(0)), Int((mul config.cluster_shape[Int(0)], config.cluster_shape[Int(1)])), config.cta_group]()[Int(0)]

clc_throttle_consumer_arv_count​

comptime clc_throttle_consumer_arv_count = compute_clc_barrier_counts[_resolve_warp_size(), _resolve_warp_size(), _resolve_warp_size(), (Int((mul _resolve_warp_size(), 4)) + _resolve_warp_size() if config.use_tma_epilogue_load else Int(0)), Int((mul config.cluster_shape[Int(0)], config.cluster_shape[Int(1)])), config.cta_group]()[Int(3)]

clc_throttle_producer_arv_count​

comptime clc_throttle_producer_arv_count = compute_clc_barrier_counts[_resolve_warp_size(), _resolve_warp_size(), _resolve_warp_size(), (Int((mul _resolve_warp_size(), 4)) + _resolve_warp_size() if config.use_tma_epilogue_load else Int(0)), Int((mul config.cluster_shape[Int(0)], config.cluster_shape[Int(1)])), config.cta_group]()[Int(2)]

CLUSTER_M​

comptime CLUSTER_M = config.cluster_shape[Int(0)]

CLUSTER_N​

comptime CLUSTER_N = config.cluster_shape[Int(1)]

CLUSTER_SIZE​

comptime CLUSTER_SIZE = (config.cluster_shape[Int(0)] * config.cluster_shape[Int(1)])

Context​

comptime Context = KernelContext[config.num_clc_pipeline_stages, config.cta_group, config.cluster_shape[Int(0)], config.cluster_shape[Int(1)]]

cta_group​

comptime cta_group = config.cta_group

CTileLayout​

comptime CTileLayout = Layout[*?, *?]

CTileLayout_splitk​

comptime CTileLayout_splitk = Layout[*?, *?]

CTmaOp​

comptime CTmaOp = TMATensorTile[c_type, Int(3), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(3), Layout[*?, *?]]()]

CTmaOp_splitk​

comptime CTmaOp_splitk = TMATensorTile[c_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()]

epi_load_consumer_arv_count​

comptime epi_load_consumer_arv_count = SIMD(Int((mul _resolve_warp_size(), 4)))

epi_load_producer_arv_count​

comptime epi_load_producer_arv_count = SIMD(ceildiv(config.block_tile_shape[Int(0)] if config.AB_swapped else config.mma_shape[Int(1)], Int(8))) if config.epilogue_is_1d else Int32(1)
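The conditional in `epi_load_producer_arv_count` is easier to follow when unrolled. The Python sketch below restates it under the assumption that the rendered expression groups as "ceildiv(...) when `epilogue_is_1d`, otherwise 1"; the function names and sample shapes are illustrative.

```python
def ceildiv(a: int, b: int) -> int:
    """Integer division rounding up."""
    return -(-a // b)


def epi_load_producer_arv_count(bm: int, mma_n: int,
                                ab_swapped: bool, epilogue_is_1d: bool) -> int:
    """Producer arrival count for the epilogue-load barrier (assumed grouping)."""
    if not epilogue_is_1d:
        return 1
    # 1D epilogue: the extent is BM when A and B are swapped, else MMA_N.
    extent = bm if ab_swapped else mma_n
    return ceildiv(extent, 8)


# Hypothetical shapes, for illustration only.
print(epi_load_producer_arv_count(128, 256, ab_swapped=False, epilogue_is_1d=True))  # 32
print(epi_load_producer_arv_count(128, 256, ab_swapped=True, epilogue_is_1d=True))   # 16
print(epi_load_producer_arv_count(128, 256, ab_swapped=False, epilogue_is_1d=False)) # 1
```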

epi_load_swizzle​

comptime epi_load_swizzle = config.epi_load_swizzle

epi_load_swizzle_elems​

comptime epi_load_swizzle_elems = (config.epi_load_swizzle.bytes() // size_of[c_type]())

EPILOGUE_LOAD_THREADS​

comptime EPILOGUE_LOAD_THREADS = WARP_SIZE if config.use_tma_epilogue_load else Int(0)

EPILOGUE_THREADS​

comptime EPILOGUE_THREADS = (Int(4) * _resolve_warp_size())

EpilogueCtx​

comptime EpilogueCtx = EpilogueWarpContext[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].opc, _resolve_warp_size(), Int((mul _resolve_warp_size(), 4))]

EpilogueLoadDescLayout​

comptime EpilogueLoadDescLayout = Layout[*?, *?]

EpilogueLoadTileLayout​

comptime EpilogueLoadTileLayout = Layout[*?, *?]

EpilogueLoadTmaOp​

comptime EpilogueLoadTmaOp = TMATensorTile[c_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()]

input_expected_bytes​

comptime input_expected_bytes = (Int((add (mul (config.block_tile_shape[Int(2)] // (config.a_swizzle.bytes() // size_of[a_type]())), (config.block_tile_shape[Int(0)] // Int(8)), (config.a_swizzle.bytes() // size_of[a_type]()), size_of[a_type](), config.cta_group, 8), (mul (config.block_tile_shape[Int(2)] // (config.b_swizzle.bytes() // size_of[b_type]())), (config.block_tile_shape[Int(1)] // Int(8)), (config.b_swizzle.bytes() // size_of[b_type]()), size_of[b_type](), config.cta_group, 8))) * config)

InputTilePipeline​

comptime InputTilePipeline = InputTilePipeline[StandardTilePayload[a_type, b_type, IndexList(config.block_tile_shape[Int(0)], config.block_tile_shape[Int(2)], __list_literal__=NoneType(None)), IndexList(config.block_tile_shape[Int(1)], config.block_tile_shape[Int(2)], __list_literal__=NoneType(None)), config.num_pipeline_stages], (config // config), config.k_group_size]

MMA_K​

comptime MMA_K = config.mma_shape[Int(2)]

MMA_M​

comptime MMA_M = config.mma_shape[Int(0)]

MMA_N​

comptime MMA_N = config.mma_shape[Int(1)]

MMA_THREADS​

comptime MMA_THREADS = WARP_SIZE

MmaCtx​

comptime MmaCtx = MmaWarpContext[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].opc, _resolve_warp_size(), Int((mul _resolve_warp_size(), 4))]

MmaEpilogueSync​

comptime MmaEpilogueSync = WarpGroupBarrier[(_resolve_warp_size() + Int((mul _resolve_warp_size(), 4))), Int(1)]

MmaOp​

comptime MmaOp = MmaOpSM100_SS[c_type, a_type, b_type, config.block_tile_shape, config.mma_shape, accum_type=BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].accum_type, cta_group=config.cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b]

num_accum_pipeline_stages​

comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages

num_clc_pipeline_stages​

comptime num_clc_pipeline_stages = config.num_clc_pipeline_stages

num_group_pipeline_stages​

comptime num_group_pipeline_stages = (config // config)

num_k_mmas​

comptime num_k_mmas = (config.block_tile_shape[Int(2)] // config.mma_shape[Int(2)])

num_m_mmas​

comptime num_m_mmas = (config.block_tile_shape[Int(0)] // (config.mma_shape[Int(0)] // config))

num_n_mmas​

comptime num_n_mmas = (config.block_tile_shape[Int(1)] // (config.mma_shape[Int(1)] // config))

num_output_stages​

comptime num_output_stages = config.num_output_stages

num_output_warps​

comptime num_output_warps = 4

num_pipeline_stages​

comptime num_pipeline_stages = config.num_pipeline_stages

NUM_THREADS​

comptime NUM_THREADS = (Int((mul _resolve_warp_size(), 7)) + _resolve_warp_size() if config.use_tma_epilogue_load else Int(0))
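The factor of 7 in `NUM_THREADS` matches the sum of the per-role thread counts documented on this page (`SCHEDULER_THREADS`, `TMA_LOAD_THREADS`, `MMA_THREADS`, `EPILOGUE_THREADS`), plus the optional `EPILOGUE_LOAD_THREADS`. The Python sketch below tallies them. It assumes a warp size of 32 and reads the rendered expression as "7 warps, plus one more warp when `use_tma_epilogue_load` is set"; both are assumptions rather than statements from the source.

```python
WARP_SIZE = 32  # assumption: typical NVIDIA warp size


def thread_budget(use_tma_epilogue_load: bool) -> dict:
    """Threads per warp role, following the comptime members on this page."""
    return {
        "scheduler": WARP_SIZE,        # SCHEDULER_THREADS
        "tma_load": WARP_SIZE,         # TMA_LOAD_THREADS
        "mma": WARP_SIZE,              # MMA_THREADS
        "epilogue": 4 * WARP_SIZE,     # EPILOGUE_THREADS
        "epilogue_load": WARP_SIZE if use_tma_epilogue_load else 0,  # EPILOGUE_LOAD_THREADS
    }


def num_threads(use_tma_epilogue_load: bool) -> int:
    """Total threads per CTA under the assumed grouping."""
    return sum(thread_budget(use_tma_epilogue_load).values())


print(num_threads(False))  # 224
print(num_threads(True))   # 256
```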

NUM_TMEM_COLS​

comptime NUM_TMEM_COLS = 512

opc​

comptime opc = OutputPipelineConfig(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].num_accum_pipeline_stages, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].stage_stride_cols, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group)

OutputM​

comptime OutputM = config.output_tile_shape[Int(0)]

OutputN​

comptime OutputN = config.output_tile_shape[Int(1)]

OutputPipeline​

comptime OutputPipeline = OutputTilePipeline[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].opc]

register_based_epilogue​

comptime register_based_epilogue = config.register_based_epilogue

Scheduler​

comptime Scheduler = TileScheduler[config.num_clc_pipeline_stages, Index[Int, Int, Int, dtype=DType.uint32](config.cluster_shape[Int(0)], config.cluster_shape[Int(1)], config.cluster_shape[Int(2)]), config.raster_order, config.block_swizzle_size]

SCHEDULER_THREADS​

comptime SCHEDULER_THREADS = WARP_SIZE

SmemType​

comptime SmemType = B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config]

stage_stride_cols​

comptime stage_stride_cols = (Int(512) // config)

TilePayload​

comptime TilePayload = StandardTilePayload[a_type, b_type, IndexList(config.block_tile_shape[Int(0)], config.block_tile_shape[Int(2)], __list_literal__=NoneType(None)), IndexList(config.block_tile_shape[Int(1)], config.block_tile_shape[Int(2)], __list_literal__=NoneType(None)), config.num_pipeline_stages]

TileWriterType​

comptime TileWriterType = TileWriter[a_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].accum_type, config.block_tile_shape, config.mma_shape, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].opc, config.c_swizzle, config.AB_swapped, config.output_tile_shape[Int(0)], config.output_tile_shape[Int(1)], config.num_output_stages, Int(4), elementwise_lambda_fn, elementwise_compute_lambda_fn, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].register_based_epilogue, True]

TileWriterType_splitk​

comptime TileWriterType_splitk = TileWriter[a_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].accum_type, config.block_tile_shape, config.mma_shape, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].opc, config.c_swizzle, config.AB_swapped, config.output_tile_shape[Int(0)], config.output_tile_shape[Int(1)], config.num_output_stages, Int(4), elementwise_lambda_fn, elementwise_compute_lambda_fn, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].register_based_epilogue]

TMA_LOAD_THREADS​

comptime TMA_LOAD_THREADS = WARP_SIZE

Tmem​

comptime Tmem = TmemAllocation[OutputPipelineConfig(config.num_accum_pipeline_stages, (Int(512) // config), config.cta_group).cta_group]

TmemDealloc​

comptime TmemDealloc = TmemDeallocBarrier[OutputPipelineConfig(config.num_accum_pipeline_stages, (Int(512) // config), config.cta_group).cta_group]

WorkIter​

comptime WorkIter = WorkIterator[config.num_clc_pipeline_stages, Index[Int, Int, Int, dtype=DType.uint32](config.cluster_shape[Int(0)], config.cluster_shape[Int(1)], config.cluster_shape[Int(2)]), config.raster_order, config.block_swizzle_size]

Methods

validate_constraints​

static def validate_constraints()

Validate parameter constraints at compile time.

init_barriers​

static def init_barriers[use_tma_epilogue_load: Bool = False](ctx: KernelContext[config.num_clc_pipeline_stages, config.cta_group, config.cluster_shape[Int(0)], config.cluster_shape[Int(1)]], input_barriers: SMemArray[SharedMemBarrier, ((config // config) * Int(2))], accum_barriers: SMemArray[SharedMemBarrier, (config * Int(2))], clc_throttle: SMemArray[SharedMemBarrier, (config * Int(2))], clc_full: SMemArray[SharedMemBarrier, config.num_clc_pipeline_stages], clc_empty: SMemArray[SharedMemBarrier, config.num_clc_pipeline_stages], tmem_dealloc: SMemArray[SharedMemBarrier, Int(1)], epi_load_barriers: SMemArray[SharedMemBarrier, (config.num_accum_pipeline_stages if config.AB_swapped or config.epilogue_is_1d else config.num_tma_epilogue_pipeline_stages if config.use_tma_epilogue_load else Int(0) * Int(2))] = SMemArray(UnsafePointer.unsafe_dangling()))

Initialize barriers. TMA descriptor prefetch is done by each kernel entry point before calling this method.

mma​

static def mma[tiles_origin: MutOrigin, //](tmem_stage: TmemStage[Self.opc], tiles: ConsumerTiles[tiles_origin, StandardTilePayload[a_type, b_type, IndexList(config.block_tile_shape[Int(0)], config.block_tile_shape[Int(2)], __list_literal__=NoneType(None)), IndexList(config.block_tile_shape[Int(1)], config.block_tile_shape[Int(2)], __list_literal__=NoneType(None)), config.num_pipeline_stages], (config // config), config.k_group_size], mma_op: MmaOpSM100_SS[accum_type=mma_op.accum_type, cta_group=mma_op.cta_group, cluster_shape=mma_op.cluster_shape, a_swizzle=mma_op.a_swizzle, b_swizzle=mma_op.b_swizzle, transpose_b=mma_op.transpose_b], elect_one_warp: Bool, iter_idx: UInt32, k_start: UInt32)

Execute MMA operations for one pipeline stage.

This is the core MMA function designed to be called within a consumer stage context:

with consumer.acquire() as tiles:
    Self.mma(stage.tmem, tiles, mma_op, ...)


load_input_tiles​

static def load_input_tiles[tiles_origin: MutOrigin, //](a_tma_op: TMATensorTile[a_type, Int(3), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(3), Layout[*?, *?]]()], b_tma_op: TMATensorTile[b_type, Int(3), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(3), Layout[*?, *?]]()], tiles: ProducerTiles[tiles_origin, StandardTilePayload[a_type, b_type, IndexList(config.block_tile_shape[Int(0)], config.block_tile_shape[Int(2)], __list_literal__=NoneType(None)), IndexList(config.block_tile_shape[Int(1)], config.block_tile_shape[Int(2)], __list_literal__=NoneType(None)), config.num_pipeline_stages], (config // config), config.k_group_size], peer_cta_coord: Tuple[Int, Int, Int], work_tile_coord: Tuple[Int, Int, Int], a_multicast_mask: UInt16, b_multicast_mask: UInt16, iter_idx: UInt32, elect_one_cta: Bool)

Load A and B tiles using 3D TMA.

Uses async_multicast_load_3d with batch coordinate from work_tile_coord[2]. For non-batched calls, batch coord is 0 (grid_dim.z = 1).


prefetch_a_tiles​

static def prefetch_a_tiles[tiles_origin: MutOrigin, //](a_tma_op: TMATensorTile[a_type, Int(3), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(3), Layout[*?, *?]]()], tiles: ProducerTiles[tiles_origin, StandardTilePayload[a_type, b_type, IndexList(config.block_tile_shape[Int(0)], config.block_tile_shape[Int(2)], __list_literal__=NoneType(None)), IndexList(config.block_tile_shape[Int(1)], config.block_tile_shape[Int(2)], __list_literal__=NoneType(None)), config.num_pipeline_stages], (config // config), config.k_group_size], peer_cta_coord: Tuple[Int, Int, Int], work_tile_coord: Tuple[Int, Int, Int], a_multicast_mask: UInt16, iter_idx: UInt32, elect_one_cta: Bool)

Load A tiles only; set full expected bytes (A+B) on the barrier.

Called before wait_on_dependent_grids() to prefetch the static weight matrix (kernel-A in swapAB mode). The barrier will not fire until the matching complete_b_tiles() call delivers the remaining B bytes.


complete_b_tiles​

static def complete_b_tiles(b_tma_op: TMATensorTile[b_type, Int(3), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(3), Layout[*?, *?]]()], stage: UInt32, barrier: UnsafePointer[SharedMemBarrier, MutUntrackedOrigin, address_space=AddressSpace.SHARED], payload: StandardTilePayload[a_type, b_type, IndexList(config.block_tile_shape[Int(0)], config.block_tile_shape[Int(2)], __list_literal__=NoneType(None)), IndexList(config.block_tile_shape[Int(1)], config.block_tile_shape[Int(2)], __list_literal__=NoneType(None)), config.num_pipeline_stages], peer_cta_coord: Tuple[Int, Int, Int], work_tile_coord: Tuple[Int, Int, Int], b_multicast_mask: UInt16, iter_idx: UInt32)

Load B tiles into a previously prefetched stage.

Delivers the remaining B bytes so that the stage barrier fires and the consumer can proceed. Pair with prefetch_a_tiles().
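The pairing of `prefetch_a_tiles()` and `complete_b_tiles()` can be pictured with a toy byte-counting barrier. The Python sketch below is a conceptual model only: `TxBarrier` and its methods are hypothetical and do not correspond to the real `SharedMemBarrier` API, and the byte counts are assumed values.

```python
class TxBarrier:
    """Toy model of a barrier that fires once all expected bytes have arrived."""

    def __init__(self) -> None:
        self.pending = 0
        self.armed = False

    def expect_bytes(self, n: int) -> None:
        # The producer declares how many bytes must land before the barrier fires.
        self.pending += n
        self.armed = True

    def deliver(self, n: int) -> None:
        # An asynchronous copy completes and reports the bytes it wrote.
        self.pending -= n

    @property
    def fired(self) -> bool:
        return self.armed and self.pending == 0


A_BYTES = 16384  # assumed size of one A tile
B_BYTES = 16384  # assumed size of one B tile

barrier = TxBarrier()

# Prefetch step: arm the barrier for A + B, but deliver only A.
barrier.expect_bytes(A_BYTES + B_BYTES)
barrier.deliver(A_BYTES)
fired_after_a = barrier.fired   # False: the B bytes are still outstanding

# Completion step: deliver B, and the barrier fires.
barrier.deliver(B_BYTES)
fired_after_b = barrier.fired   # True

print(fired_after_a, fired_after_b)
```

Because the full A+B count is set up front, the consumer cannot observe a stage that holds only A.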


load_input_tiles_splitk​

static def load_input_tiles_splitk[a_tma_origin: ImmutOrigin, b_tma_origin: ImmutOrigin, tiles_origin: MutOrigin, //](a_loader: TileLoader[a_tma_origin, a_type, Layout[*?, *?], Layout[*?, *?], cta_group=config.cta_group], b_loader: TileLoader[b_tma_origin, b_type, Layout[*?, *?], Layout[*?, *?], cta_group=config.cta_group], tiles: ProducerTiles[tiles_origin, StandardTilePayload[a_type, b_type, IndexList(config.block_tile_shape[Int(0)], config.block_tile_shape[Int(2)], __list_literal__=NoneType(None)), IndexList(config.block_tile_shape[Int(1)], config.block_tile_shape[Int(2)], __list_literal__=NoneType(None)), config.num_pipeline_stages], (config // config), config.k_group_size], iter_idx: UInt32, work_m_coord: Int, work_n_coord: Int, peer_cta_coord: Tuple[Int, Int, Int], elect_one_cta: Bool)

Load k_group_size A and B tiles using 2D TMA (for split-K only).

Orchestrates the tile loading operation including:

  • expect_bytes signaling
  • k-group iteration
  • Peer CTA slicing for 2-SM MMA


epilogue_load_producer​

static def epilogue_load_producer[_epi_pipeline_stages: Int](epi_load_iter: WorkIterator[config.num_clc_pipeline_stages, Index[Int, Int, Int, dtype=DType.uint32](config.cluster_shape[Int(0)], config.cluster_shape[Int(1)], config.cluster_shape[Int(2)]), config.raster_order, config.block_swizzle_size], mut epilogue_load_pipeline: ProducerConsumerPipeline[_epi_pipeline_stages], epilogue_load_tma_op: TMATensorTile[c_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()], bias_1d_tile: TileTensor[c_type, Layout[*?, *?], ImmutAnyOrigin], epilogue_load_tiles: SMemTileArray2DRowMajor[c_type, Int(1) if config.epilogue_is_1d else config.mma_shape[Int(1)] if config.AB_swapped else config.block_tile_shape[Int(0)], config.block_tile_shape[Int(0)] if config.AB_swapped else config.mma_shape[Int(1)] if config.epilogue_is_1d else config.block_tile_shape[Int(0)] if config.AB_swapped else config.output_tile_shape[Int(1)], config.num_accum_pipeline_stages if config.AB_swapped or config.epilogue_is_1d else config.num_tma_epilogue_pipeline_stages if config.use_tma_epilogue_load else Int(0)], mnk: StaticTuple[UInt32, Int(3)])

Load epilogue tiles (bias) from GMEM to SMEM for each output tile.

Handles three cases based on config:

  • 1D bias: warp-wide cp.async with zero-fill for OOB elements
  • AB_swapped: full MMA_N x BM TMA per output tile
  • non-AB_swapped: BM x stageN strips in stage-outer/col_wg-inner order

run​

static def run(a_tma_op: TMATensorTile[a_type, Int(3), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(3), Layout[*?, *?]]()], b_tma_op: TMATensorTile[b_type, Int(3), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(3), Layout[*?, *?]]()], c_tma_op: TMATensorTile[c_type, Int(3), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(3), Layout[*?, *?]]()], epilogue_load_tma_op: TMATensorTile[c_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()], bias_1d_tile: TileTensor[c_type, Layout[*?, *?], ImmutAnyOrigin], cluster_dim: StaticTuple[Int32, Int(3)], mnk: StaticTuple[UInt32, Int(3)], workspace: Span[UInt64, MutAnyOrigin])

Main kernel entry point for SM100 matrix multiplication.

Always uses 3D TMA descriptors. For non-batched inputs, batch = 1 and batch_coord = 0 (k_start = block_idx.z, which is 0 when grid_dim.z = 1). For batched inputs, grid_dim.z = batch_size and batch_coord is taken from k_start.

run_splitk​

static def run_splitk[reduction_layout: TensorLayout](a_tma_op: TMATensorTile[a_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()], b_tma_op: TMATensorTile[b_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()], c_tma_op: TMATensorTile[c_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()], reduction_tensor: TileTensor[MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type, reduction_layout, MutAnyOrigin], lock_ptr: UnsafePointer[UInt8, MutAnyOrigin], cluster_dim: StaticTuple[Int32, Int(3)], mnk: StaticTuple[UInt32, Int(3)], workspace: Span[UInt64, MutAnyOrigin])

Split-K kernel entry point for better parallelism on small problems.

Split-K divides the K dimension across multiple CTAs, with each CTA computing a partial result that is then reduced.
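The split-K idea can be illustrated independently of the GPU kernel. The pure-Python sketch below partitions K into equal chunks, computes one partial product per chunk (standing in for one CTA each), and sums the partials. It models only the arithmetic: the function names are illustrative, and the real kernel's reduction through `reduction_tensor` and `lock_ptr` is not modeled.

```python
def matmul(a, b):
    """Reference matmul on nested lists: (M x K) @ (K x N)."""
    m, k, n = len(a), len(b), len(b[0])
    return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(n)]
            for i in range(m)]


def split_k_matmul(a, b, num_splits):
    """Compute a @ b by splitting K into num_splits chunks and reducing."""
    m, k, n = len(a), len(b), len(b[0])
    if k % num_splits != 0:
        raise ValueError("this sketch assumes K is divisible by num_splits")
    chunk = k // num_splits

    # Each split plays the role of one CTA working on its own K range.
    partials = []
    for s in range(num_splits):
        lo, hi = s * chunk, (s + 1) * chunk
        a_slice = [row[lo:hi] for row in a]
        b_slice = b[lo:hi]
        partials.append(matmul(a_slice, b_slice))

    # Reduction: element-wise sum of the partial results.
    out = [[0] * n for _ in range(m)]
    for part in partials:
        for i in range(m):
            for j in range(n):
                out[i][j] += part[i][j]
    return out


a = [[1, 2, 3, 4], [5, 6, 7, 8]]
b = [[1, 0], [0, 1], [1, 0], [0, 1]]
print(split_k_matmul(a, b, num_splits=2))  # [[4, 6], [12, 14]]
```

With small M and N there are few output tiles to distribute, so splitting K gives more CTAs independent work at the cost of the extra reduction step.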
