Mojo struct
BlackwellMatmulSM100Kernel
struct BlackwellMatmulSM100Kernel[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], cluster_shape: StaticTuple[Int32, 3] = StaticTuple(Int32(1)), elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, pdl_level: PDLLevel = PDLLevel(), max_profiled_tiles_per_SM: UInt32 = UInt32(0)]
Blackwell SM100 GEMM kernel with warp specialization.
This struct unifies all parameters and derived types for the SM100 matmul kernel, providing:
- Compile-time parameter validation
- Centralized derived type computation
- Factory methods for kernel components
- Multiple kernel entry points (standard, split-k)
The SM100 kernel uses:
- Tensor Memory (TMEM) for MMA accumulators
- Cluster Launch Control (CLC) for dynamic tile scheduling
- Warp specialization: Scheduler, TMA Load, MMA, Epilogue warps
- Software pipelining for overlapping compute and memory operations
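Software pipelining of this kind is usually tracked with two values: an index into a ring of stage buffers, and a one-bit phase that flips every time the ring wraps. The phase lets a waiter tell the current fill of a slot apart from the previous one. The Python sketch below models only that bookkeeping, on the host. `PipelineState` and `trace` are hypothetical names, not part of this kernel's API, and the kernel's own pipeline types may track state differently.

```python
from dataclasses import dataclass


@dataclass
class PipelineState:
    """Hypothetical model of circular-buffer pipeline bookkeeping."""
    num_stages: int
    index: int = 0  # current slot in the ring
    phase: int = 0  # flips every time the ring wraps around

    def step(self) -> None:
        # Advance to the next slot; flip the phase bit on wrap-around so a
        # waiter can tell this lap's fill of a slot from the previous lap's.
        self.index += 1
        if self.index == self.num_stages:
            self.index = 0
            self.phase ^= 1


def trace(num_stages: int, iterations: int) -> list[tuple[int, int]]:
    """Return the (stage, phase) pair visited on each iteration."""
    state = PipelineState(num_stages)
    visited = []
    for _ in range(iterations):
        visited.append((state.index, state.phase))
        state.step()
    return visited


if __name__ == "__main__":
    print(trace(3, 7))
```

For example, `trace(3, 7)` visits stages 0, 1, 2 with phase 0, then stages 0, 1, 2 with phase 1, and returns to `(0, 0)` on the seventh iteration.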
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
a_expected_bytes
comptime a_expected_bytes = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.Layouts.a_tile_elems * size_of[a_type]())
a_smem_layout
comptime a_smem_layout = Layout(Coord(Coord(Idx[8](), Idx[(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BM // 8)]()), Coord(Idx[(config.a_swizzle.bytes() // size_of[a_type]())](), Idx[(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK // (config.a_swizzle.bytes() // size_of[a_type]()))]())), Coord(Coord(Idx[(config.a_swizzle.bytes() // size_of[a_type]())](), Idx[(8 * (config.a_swizzle.bytes() // size_of[a_type]()))]()), Coord(Idx[1](), Idx[0 if (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK == (config.a_swizzle.bytes() // size_of[a_type]())) else (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BM * (config.a_swizzle.bytes() // size_of[a_type]()))]()))).to_layout()
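The nested `Coord` expression in `a_smem_layout` can be read as a (shape, stride) pair. The shape is `((8, BM // 8), (a_swizzle_elems, BK // a_swizzle_elems))`. The stride is `((a_swizzle_elems, 8 * a_swizzle_elems), (1, 0 or BM * a_swizzle_elems))`. The Python sketch below evaluates those formulas and maps a logical `(m, k)` coordinate to an element offset. It makes three simplifications:

- It assumes CuTe-style colexicographic decomposition of each hierarchical mode.
- It ignores the XOR permutation that the swizzle mode applies on top.
- Its function names and its bf16 / 128-byte example values are illustrative only.

```python
def a_smem_layout(bm: int, bk: int, swizzle_bytes: int, elem_bytes: int):
    """Return (shape, stride) following the a_smem_layout expression."""
    sw = swizzle_bytes // elem_bytes  # a_swizzle_elems
    shape = ((8, bm // 8), (sw, bk // sw))
    # When BK spans exactly one swizzle row there is a single K block,
    # so its stride is irrelevant and the expression uses 0.
    k_block_stride = 0 if bk == sw else bm * sw
    stride = ((sw, 8 * sw), (1, k_block_stride))
    return shape, stride


def offset(m: int, k: int, shape, stride) -> int:
    """Map a logical (m, k) element to its pre-swizzle element offset."""
    (m_inner, _), (k_inner, _) = shape
    (m0_stride, m1_stride), (k0_stride, k1_stride) = stride
    # Split each coordinate across its two sub-modes, inner mode first.
    m0, m1 = m % m_inner, m // m_inner
    k0, k1 = k % k_inner, k // k_inner
    return m0 * m0_stride + m1 * m1_stride + k0 * k0_stride + k1 * k1_stride


if __name__ == "__main__":
    # bf16 (2 bytes) with a 128-byte swizzle gives 64 elements per row.
    shape, stride = a_smem_layout(bm=128, bk=128, swizzle_bytes=128, elem_bytes=2)
    print(shape, stride)
    print(offset(1, 0, shape, stride), offset(0, 64, shape, stride))
```

Under this reading, consecutive M rows sit one swizzle row (128 bytes here) apart. Each further K block of `a_swizzle_elems` elements begins only after all `BM` rows of the previous block.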
a_swizzle_elems
comptime a_swizzle_elems = (config.a_swizzle.bytes() // size_of[a_type]())
a_tile_dim0
comptime a_tile_dim0 = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BM // BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_N)
a_tma_load_size
comptime a_tma_load_size = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].a_tile_dim0 * BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].a_swizzle_elems)
a_tma_rows
comptime a_tma_rows = BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].a_tile_dim0
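The A-side TMA members reduce to simple integer arithmetic. The sketch below evaluates them for example values: bf16, a 128-byte swizzle, `BM = 128` and `CLUSTER_N = 2`. Dividing `BM` by `CLUSTER_N` suggests that each CTA loads a `1 / CLUSTER_N` slice of the A tile, presumably so the slices can be multicast to the CTAs that share it. That reading is an interpretation; this page does not state it. `a_tma_params` is a hypothetical helper.

```python
def a_tma_params(bm: int, cluster_n: int, swizzle_bytes: int, elem_bytes: int) -> dict:
    """Evaluate the A-side TMA members for one configuration."""
    a_swizzle_elems = swizzle_bytes // elem_bytes  # elements per swizzle row
    a_tile_dim0 = bm // cluster_n                  # rows of A per load
    return {
        "a_swizzle_elems": a_swizzle_elems,
        "a_tile_dim0": a_tile_dim0,
        "a_tma_rows": a_tile_dim0,
        "a_tma_load_size": a_tile_dim0 * a_swizzle_elems,  # elements per load
    }


if __name__ == "__main__":
    print(a_tma_params(bm=128, cluster_n=2, swizzle_bytes=128, elem_bytes=2))
```

With these values, `a_tma_load_size` is 4096 elements, which is 8192 bytes of bf16.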
accum_layout
comptime accum_layout = Layout.row_major(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_M, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_N)
accum_pipeline_consumer_arv_count
comptime accum_pipeline_consumer_arv_count = compute_accum_barrier_counts[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group]()[1]
accum_pipeline_producer_arv_count
comptime accum_pipeline_producer_arv_count = compute_accum_barrier_counts[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group]()[0]
accum_type
comptime accum_type = MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type
AccumTensor
comptime AccumTensor = TmemTensor[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].accum_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].accum_layout, cta_group=BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group]
ADescLayout
comptime ADescLayout = Layout[*?, *?]
ADescLayout_splitk
comptime ADescLayout_splitk = Layout[*?, *?]
ATileLayout
comptime ATileLayout = Layout[*?, *?]
ATileLayout_splitk
comptime ATileLayout_splitk = Layout[*?, *?]
ATmaOp
comptime ATmaOp = TMATensorTile[a_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]
ATmaOp_splitk
comptime ATmaOp_splitk = TMATensorTile[a_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]
b_expected_bytes
comptime b_expected_bytes = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.Layouts.b_tile_elems * size_of[b_type]())
b_smem_layout
comptime b_smem_layout = Layout(Coord(Coord(Idx[8](), Idx[(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BN // 8)]()), Coord(Idx[(config.b_swizzle.bytes() // size_of[b_type]())](), Idx[(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK // (config.b_swizzle.bytes() // size_of[b_type]()))]())), Coord(Coord(Idx[(config.b_swizzle.bytes() // size_of[b_type]())](), Idx[(8 * (config.b_swizzle.bytes() // size_of[b_type]()))]()), Coord(Idx[1](), Idx[0 if (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK == (config.b_swizzle.bytes() // size_of[b_type]())) else (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BN * (config.b_swizzle.bytes() // size_of[b_type]()))]()))).to_layout() if transpose_b else Layout(Coord(Coord(Idx[8](), Idx[(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK // 8)]()), Coord(Idx[(config.b_swizzle.bytes() // size_of[b_type]())](), Idx[(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BN // (config.b_swizzle.bytes() // size_of[b_type]()))]())), Coord(Coord(Idx[(config.b_swizzle.bytes() // size_of[b_type]())](), Idx[(8 * (config.b_swizzle.bytes() // size_of[b_type]()))]()), Coord(Idx[1](), Idx[0 if (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BN == (config.b_swizzle.bytes() // size_of[b_type]())) else (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK * (config.b_swizzle.bytes() // size_of[b_type]()))]()))).transpose().to_layout()
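The two branches of `b_smem_layout` differ only in which dimension is contiguous:

- With `transpose_b` (B stored N x K, K-major), the layout has the same form as the A layout, with `BN` in place of `BM`.
- Otherwise, the layout is built over `(BK, BN)` and then transposed.

The sketch below evaluates both branches. It assumes that `.transpose()` swaps the two top-level modes of a rank-2 layout. The helper name and example values are illustrative.

```python
def b_smem_layout(bn: int, bk: int, swizzle_bytes: int, elem_bytes: int, transpose_b: bool):
    """Return (shape, stride) for either branch of b_smem_layout."""
    sw = swizzle_bytes // elem_bytes  # b_swizzle_elems
    if transpose_b:
        # K-major B (stored N x K): same form as the A layout with BN rows.
        shape = ((8, bn // 8), (sw, bk // sw))
        stride = ((sw, 8 * sw), (1, 0 if bk == sw else bn * sw))
        return shape, stride
    # MN-major B: build the layout over (k, n), then swap the two
    # top-level modes so it is indexed as (n, k) like the other branch.
    shape_kn = ((8, bk // 8), (sw, bn // sw))
    stride_kn = ((sw, 8 * sw), (1, 0 if bn == sw else bk * sw))
    return (shape_kn[1], shape_kn[0]), (stride_kn[1], stride_kn[0])


if __name__ == "__main__":
    for flag in (True, False):
        print(flag, b_smem_layout(bn=128, bk=64, swizzle_bytes=128, elem_bytes=2, transpose_b=flag))
```

Without `transpose_b`, consecutive `n` indices are adjacent in memory (stride 1). With `transpose_b`, consecutive `k` indices are adjacent instead.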
b_swizzle_elems
comptime b_swizzle_elems = (config.b_swizzle.bytes() // size_of[b_type]())
b_tile_dim0
comptime b_tile_dim0 = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BN // (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_M // BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group))
b_tma_load_size
comptime b_tma_load_size = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].b_tile_dim0 * BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].b_swizzle_elems)
b_tma_rows
comptime b_tma_rows = BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].b_tile_dim0
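`b_tile_dim0` divides `BN` by `CLUSTER_M // cta_group` rather than by `CLUSTER_M`. One possible reading (an assumption) is that with `cta_group = 2`, two CTAs cooperate on one MMA. The B tile would then be split across CTA pairs along the cluster's M dimension, not across individual CTAs. The sketch below assumes that `CLUSTER_M` is a multiple of `cta_group`. `b_tma_params` is a hypothetical helper.

```python
def b_tma_params(bn: int, cluster_m: int, cta_group: int, swizzle_bytes: int, elem_bytes: int) -> dict:
    """Evaluate the B-side TMA members for one configuration."""
    b_swizzle_elems = swizzle_bytes // elem_bytes
    # Divide BN by the number of CTA groups along the cluster's M dimension.
    b_tile_dim0 = bn // (cluster_m // cta_group)
    return {
        "b_swizzle_elems": b_swizzle_elems,
        "b_tile_dim0": b_tile_dim0,
        "b_tma_rows": b_tile_dim0,
        "b_tma_load_size": b_tile_dim0 * b_swizzle_elems,
    }


if __name__ == "__main__":
    print(b_tma_params(bn=128, cluster_m=2, cta_group=2, swizzle_bytes=128, elem_bytes=2))
    print(b_tma_params(bn=128, cluster_m=2, cta_group=1, swizzle_bytes=128, elem_bytes=2))
```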
BDescLayout
comptime BDescLayout = Layout[*?, *?]
BDescLayout_splitk
comptime BDescLayout_splitk = Layout[*?, *?]
BK
comptime BK = config.block_tile_shape[2]
BM
comptime BM = config.block_tile_shape[0]
BN
comptime BN = config.block_tile_shape[1]
BTileLayout
comptime BTileLayout = Layout[*?, *?]
BTileLayout_splitk
comptime BTileLayout_splitk = Layout[*?, *?]
BTmaOp
comptime BTmaOp = TMATensorTile[b_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]
BTmaOp_splitk
comptime BTmaOp_splitk = TMATensorTile[b_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]
c_swizzle_elems
comptime c_swizzle_elems = (config.c_swizzle.bytes() // size_of[c_type]())
c_tile_dim0
comptime c_tile_dim0 = BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].OutputM if ((BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_M == 256) or (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group == 1) or config.AB_swapped) else 64
c_tile_dim1
comptime c_tile_dim1 = BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].c_swizzle_elems if config.AB_swapped else BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].OutputN
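Read the conditional in `c_tile_dim0` as `(MMA_M == 256) or (cta_group == 1) or config.AB_swapped`. Under that reading, the output tile dimensions can be evaluated as shown below. The example values are illustrative. With `MMA_M = 128` and `cta_group = 2`, the 64-row result is consistent with each CTA of the pair owning half of the MMA's rows, although the page does not say so. `c_tile_dims` is a hypothetical helper.

```python
def c_tile_dims(output_m: int, output_n: int, mma_m: int, cta_group: int,
                ab_swapped: bool, c_swizzle_bytes: int, c_elem_bytes: int):
    """Return (c_tile_dim0, c_tile_dim1) for one configuration."""
    c_swizzle_elems = c_swizzle_bytes // c_elem_bytes
    full_rows = mma_m == 256 or cta_group == 1 or ab_swapped
    c_tile_dim0 = output_m if full_rows else 64
    c_tile_dim1 = c_swizzle_elems if ab_swapped else output_n
    return c_tile_dim0, c_tile_dim1


if __name__ == "__main__":
    # bf16 output (2 bytes) with a 128-byte C swizzle.
    print(c_tile_dims(128, 32, mma_m=128, cta_group=2, ab_swapped=False,
                      c_swizzle_bytes=128, c_elem_bytes=2))
```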
CDescLayout
comptime CDescLayout = Layout[*?, *?]
CDescLayout_splitk
comptime CDescLayout_splitk = Layout[*?, *?]
clc_consumer_arv_count
comptime clc_consumer_arv_count = compute_clc_barrier_counts[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SCHEDULER_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].TMA_LOAD_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_SIZE, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group]()[1]
clc_producer_arv_count
comptime clc_producer_arv_count = compute_clc_barrier_counts[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SCHEDULER_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].TMA_LOAD_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_SIZE, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group]()[0]
clc_throttle_consumer_arv_count
comptime clc_throttle_consumer_arv_count = compute_clc_barrier_counts[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SCHEDULER_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].TMA_LOAD_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_SIZE, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group]()[3]
clc_throttle_producer_arv_count
comptime clc_throttle_producer_arv_count = compute_clc_barrier_counts[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SCHEDULER_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].TMA_LOAD_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_SIZE, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group]()[2]
CLUSTER_M
comptime CLUSTER_M = config.cluster_shape[0]
CLUSTER_N
comptime CLUSTER_N = config.cluster_shape[1]
CLUSTER_SIZE
comptime CLUSTER_SIZE = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_M * BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_N)
Context
comptime Context = KernelContext[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].num_clc_pipeline_stages, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_M, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_N]
cta_group
comptime cta_group = config.cta_group
CTileLayout
comptime CTileLayout = Layout[*?, *?]
CTileLayout_splitk
comptime CTileLayout_splitk = Layout[*?, *?]
CTmaOp
comptime CTmaOp = TMATensorTile[c_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]
CTmaOp_splitk
comptime CTmaOp_splitk = TMATensorTile[c_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]
EPILOGUE_THREADS
comptime EPILOGUE_THREADS = (4 * WARP_SIZE)
EpilogueCtx
comptime EpilogueCtx = EpilogueWarpContext[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].opc, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS]
input_expected_bytes
comptime input_expected_bytes = ((BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group * (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].a_expected_bytes + BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].b_expected_bytes)) * config)
InputTilePipeline
comptime InputTilePipeline = InputTilePipeline[StandardTilePayload[a_type, b_type, IndexList(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BM, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK, __list_literal__=NoneType(None)), IndexList(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BN, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK, __list_literal__=NoneType(None)), BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.num_pipeline_stages], BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.num_group_pipeline_stages, config.k_group_size]
MMA_K
comptime MMA_K = config.mma_shape[2]
MMA_M
comptime MMA_M = config.mma_shape[0]
MMA_N
comptime MMA_N = config.mma_shape[1]
MMA_THREADS
comptime MMA_THREADS = WARP_SIZE
MmaCtx
comptime MmaCtx = MmaWarpContext[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].opc, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS]
MmaEpilogueSync
comptime MmaEpilogueSync = WarpGroupBarrier[(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS + BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS), 1]
MmaOp
comptime MmaOp = MmaOpSM100_SS[c_type, a_type, b_type, config.block_tile_shape, config.mma_shape, accum_type=BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].accum_type, cta_group=BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b]
num_accum_pipeline_stages
comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages
num_clc_pipeline_stages
comptime num_clc_pipeline_stages = config.num_clc_pipeline_stages
num_group_pipeline_stages
comptime num_group_pipeline_stages = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].num_pipeline_stages // config)
num_k_mmas
comptime num_k_mmas = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK // BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_K)
num_m_mmas
comptime num_m_mmas = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BM // (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_M // BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group))
num_n_mmas
comptime num_n_mmas = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BN // (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_N // BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group))
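The MMA counts divide each block-tile extent by the per-instruction extent. `MMA_M` and `MMA_N` are first divided by `cta_group`. The sketch below mirrors those three formulas. Its divisibility check is the sketch's own addition and is not necessarily what `validate_constraints` enforces. `MMA_K = 16` is only an example value, and `mma_counts` is a hypothetical helper.

```python
def mma_counts(bm: int, bn: int, bk: int, mma_m: int, mma_n: int, mma_k: int, cta_group: int):
    """Return (num_m_mmas, num_n_mmas, num_k_mmas) for one block tile."""
    per_cta_m = mma_m // cta_group
    per_cta_n = mma_n // cta_group
    # Sketch-only check: require exact tiling so no remainder is silently dropped.
    for extent, step in ((bm, per_cta_m), (bn, per_cta_n), (bk, mma_k)):
        if step == 0 or extent % step:
            raise ValueError(f"{extent} is not a multiple of {step}")
    return bm // per_cta_m, bn // per_cta_n, bk // mma_k


if __name__ == "__main__":
    # 2-CTA MMA of 256 x 256 x 16 over a 128 x 128 x 64 block tile.
    print(mma_counts(128, 128, 64, 256, 256, 16, cta_group=2))
```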
num_output_stages
comptime num_output_stages = config.num_output_stages
num_output_warps
comptime num_output_warps = 4
num_pipeline_stages
comptime num_pipeline_stages = config.num_pipeline_stages
NUM_THREADS
comptime NUM_THREADS = (((BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SCHEDULER_THREADS + BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].TMA_LOAD_THREADS) + BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS) + BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS)
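On NVIDIA GPUs `WARP_SIZE` is 32, so the four roles add up to 32 + 32 + 32 + 128 = 224 threads per CTA. `MmaEpilogueSync` spans the 32 + 128 = 160 MMA and epilogue threads. The sketch below computes these totals and maps a thread index to a role. The role order it uses (scheduler, TMA load, MMA, epilogue) is an assumption for illustration and is not taken from the kernel. `role_of` and `MMA_EPILOGUE_SYNC_THREADS` are hypothetical names.

```python
WARP_SIZE = 32  # NVIDIA warp width

SCHEDULER_THREADS = WARP_SIZE
TMA_LOAD_THREADS = WARP_SIZE
MMA_THREADS = WARP_SIZE
EPILOGUE_THREADS = 4 * WARP_SIZE
NUM_THREADS = SCHEDULER_THREADS + TMA_LOAD_THREADS + MMA_THREADS + EPILOGUE_THREADS
MMA_EPILOGUE_SYNC_THREADS = MMA_THREADS + EPILOGUE_THREADS  # hypothetical name

# Hypothetical role order; the real kernel may assign warps differently.
ROLES = (
    ("scheduler", SCHEDULER_THREADS),
    ("tma_load", TMA_LOAD_THREADS),
    ("mma", MMA_THREADS),
    ("epilogue", EPILOGUE_THREADS),
)


def role_of(thread_idx: int) -> str:
    """Return the warp role of a thread under the assumed role order."""
    if not 0 <= thread_idx < NUM_THREADS:
        raise ValueError(f"thread index {thread_idx} out of range")
    start = 0
    for name, count in ROLES:
        if thread_idx < start + count:
            return name
        start += count
    raise AssertionError("unreachable")


if __name__ == "__main__":
    print(NUM_THREADS, MMA_EPILOGUE_SYNC_THREADS, role_of(100))
```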
NUM_TMEM_COLS
comptime NUM_TMEM_COLS = 512
opc
comptime opc = OutputPipelineConfig(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].num_accum_pipeline_stages, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].stage_stride_cols, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group)
OutputM
comptime OutputM = config.output_tile_shape[0]
OutputN
comptime OutputN = config.output_tile_shape[1]
OutputPipeline
comptime OutputPipeline = OutputTilePipeline[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].opc]
register_based_epilogue
comptime register_based_epilogue = config.register_based_epilogue
Scheduler
comptime Scheduler = TileScheduler[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].num_clc_pipeline_stages, Index[Int, Int, Int, dtype=DType.uint32](config.cluster_shape[0], config.cluster_shape[1], config.cluster_shape[2]), config.raster_order, config.block_swizzle_size]
SCHEDULER_THREADS
comptime SCHEDULER_THREADS = WARP_SIZE
SmemType
comptime SmemType = B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config]
stage_stride_cols
comptime stage_stride_cols = (512 // BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].num_accum_pipeline_stages)
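`stage_stride_cols` splits the 512 Tensor Memory columns evenly across the accumulator pipeline stages. On SM100, TMEM provides 512 columns of 32-bit cells across 128 lanes per SM. The sketch below computes each stage's starting column. `accum_stage_columns` is a hypothetical helper.

```python
NUM_TMEM_COLS = 512


def accum_stage_columns(num_accum_pipeline_stages: int) -> list[int]:
    """Starting TMEM column of each accumulator stage."""
    stage_stride_cols = NUM_TMEM_COLS // num_accum_pipeline_stages
    return [stage * stage_stride_cols for stage in range(num_accum_pipeline_stages)]


if __name__ == "__main__":
    print(accum_stage_columns(2), accum_stage_columns(4))
```

With two stages, the regions start at columns 0 and 256. Presumably this is what lets the MMA warp fill one region while the epilogue drains the other.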
TilePayload
comptime TilePayload = StandardTilePayload[a_type, b_type, IndexList(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BM, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK, __list_literal__=NoneType(None)), IndexList(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BN, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK, __list_literal__=NoneType(None)), BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.num_pipeline_stages]
TileWriterType
comptime TileWriterType = TileWriter[a_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].accum_type, config.block_tile_shape, config.mma_shape, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].opc, config.c_swizzle, config.AB_swapped, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.OutputM, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.OutputN, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.num_output_stages, 4, elementwise_lambda_fn, elementwise_compute_lambda_fn, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].register_based_epilogue, True]
TileWriterType_splitk
comptime TileWriterType_splitk = TileWriter[a_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].accum_type, config.block_tile_shape, config.mma_shape, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].opc, config.c_swizzle, config.AB_swapped, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.OutputM, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.OutputN, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.num_output_stages, 4, elementwise_lambda_fn, elementwise_compute_lambda_fn, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].register_based_epilogue]
TMA_LOAD_THREADS
comptime TMA_LOAD_THREADS = WARP_SIZE
Tmem
comptime Tmem = TmemAllocation[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].opc.cta_group]
TmemDealloc
comptime TmemDealloc = TmemDeallocBarrier[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].opc.cta_group]
Methods
validate_constraints
static validate_constraints()
Validate parameter constraints at compile time.
init_barriers
static init_barriers(ctx: KernelContext[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].num_clc_pipeline_stages, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_M, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_N], input_barriers: SMemArray[SharedMemBarrier, (B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].num_group_pipeline_stages * 2)], accum_barriers: SMemArray[SharedMemBarrier, (B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].num_accum_pipeline_stages * 2)], clc_throttle: SMemArray[SharedMemBarrier, (B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].num_clc_pipeline_stages * 2)], clc_full: SMemArray[SharedMemBarrier, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].num_clc_pipeline_stages], clc_empty: SMemArray[SharedMemBarrier, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].num_clc_pipeline_stages], tmem_dealloc: SMemArray[SharedMemBarrier, 1])
Initialize barriers. TMA descriptor prefetch is done by each kernel entry point before calling this method.
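As a rough illustration of how many barriers this method initializes, the Python sketch below mirrors the SMemArray sizes in the signature above. The stage counts passed in are hypothetical, reading each "* 2" array as one full plus one empty barrier per stage is an assumption, and the 8-byte size assumes each SharedMemBarrier is a single 64-bit mbarrier.

```python
# Mirrors the SMemArray sizes from the init_barriers signature.
# Assumptions: stage counts are hypothetical; each SharedMemBarrier is
# one 64-bit mbarrier (8 bytes).
MBARRIER_BYTES = 8


def barrier_counts(num_group_pipeline_stages, num_accum_pipeline_stages,
                   num_clc_pipeline_stages):
    return {
        # Assumed pairing: one "full" and one "empty" barrier per stage.
        "input_barriers": num_group_pipeline_stages * 2,
        "accum_barriers": num_accum_pipeline_stages * 2,
        "clc_throttle": num_clc_pipeline_stages * 2,
        "clc_full": num_clc_pipeline_stages,
        "clc_empty": num_clc_pipeline_stages,
        "tmem_dealloc": 1,
    }


counts = barrier_counts(4, 2, 2)
total = sum(counts.values())
print(total, total * MBARRIER_BYTES)  # 21 168
```

Only the input and accumulator pipelines scale with the configured stage depth, so deepening the pipeline adds two barriers per extra stage under this reading.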
mma
static mma[tiles_origin: MutOrigin, //](tmem_stage: TmemStage[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].opc], tiles: ConsumerTiles[tiles_origin, StandardTilePayload[a_type, b_type, IndexList(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BM, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK, __list_literal__=NoneType(None)), IndexList(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BN, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK, __list_literal__=NoneType(None)), BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.num_pipeline_stages], BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.num_group_pipeline_stages, config.k_group_size], mma_op: MmaOpSM100_SS[accum_type=mma_op.accum_type, cta_group=mma_op.cta_group, cluster_shape=mma_op.cluster_shape, a_swizzle=mma_op.a_swizzle, b_swizzle=mma_op.b_swizzle, transpose_b=mma_op.transpose_b], elect_one_warp: Bool, iter_idx: UInt32, k_start: UInt32)
Execute MMA operations for one pipeline stage.
This is the core MMA function designed to be called within a consumer stage context:
with consumer.acquire() as tiles:
    Self.mma(stage.tmem, tiles, mma_op, ...)

Args:
- tmem_stage (TmemStage[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].opc]): TMEM stage for accumulators.
- tiles (ConsumerTiles[tiles_origin, StandardTilePayload[a_type, b_type, IndexList(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BM, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK, __list_literal__=NoneType(None)), IndexList(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BN, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK, __list_literal__=NoneType(None)), BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.num_pipeline_stages], BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.num_group_pipeline_stages, config.k_group_size]): ConsumerTiles context with encapsulated tile access.
- mma_op (MmaOpSM100_SS[accum_type=mma_op.accum_type, cta_group=mma_op.cta_group, cluster_shape=mma_op.cluster_shape, a_swizzle=mma_op.a_swizzle, b_swizzle=mma_op.b_swizzle, transpose_b=mma_op.transpose_b]): The MMA operation instance.
- elect_one_warp (Bool): Whether this warp should execute.
- iter_idx (UInt32): K iteration index.
- k_start (UInt32): Starting K iteration (for init_c determination).
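The role of k_start can be pictured with a small Python model of one TMEM accumulator stage. It assumes init_c is true exactly when iter_idx == k_start, so the first K iteration of a tile overwrites whatever the reused TMEM stage still holds and later iterations accumulate; the function and its integer "partials" are hypothetical stand-ins for the per-iteration MMA results.

```python
# Toy model of the accumulate-vs-overwrite decision driven by k_start.
# Assumption: init_c is true exactly when iter_idx == k_start.


def run_tile(partials, k_start, stale_value):
    """Accumulate per-K-iteration partial products into one TMEM "stage".

    partials[i] stands in for the MMA result of the A/B tiles at
    K iteration k_start + i; stale_value is whatever the TMEM stage
    held from the previous output tile.
    """
    acc = stale_value
    for iter_idx in range(k_start, k_start + len(partials)):
        init_c = iter_idx == k_start  # first iteration overwrites
        partial = partials[iter_idx - k_start]
        acc = partial if init_c else acc + partial
    return acc


print(run_tile([1, 2, 3], k_start=0, stale_value=999))  # 6: stale 999 discarded
print(run_tile([4, 5], k_start=7, stale_value=-1))      # 9
```

Comparing against k_start rather than 0 is what lets the same MMA path serve split-K, where a CTA's first K iteration is not iteration 0.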
load_input_tiles
static load_input_tiles[tiles_origin: MutOrigin, //](a_tma_op: TMATensorTile[a_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], b_tma_op: TMATensorTile[b_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], tiles: ProducerTiles[tiles_origin, StandardTilePayload[a_type, b_type, IndexList(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BM, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK, __list_literal__=NoneType(None)), IndexList(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BN, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK, __list_literal__=NoneType(None)), BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.num_pipeline_stages], BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.num_group_pipeline_stages, config.k_group_size], peer_cta_coord: Tuple[Int, Int, Int], work_tile_coord: Tuple[Int, Int, Int], a_multicast_mask: UInt16, b_multicast_mask: UInt16, iter_idx: UInt32, elect_one_cta: Bool)
Load A and B tiles using 3D TMA.
Uses async_multicast_load_3d, taking the batch coordinate from work_tile_coord[2]. For non-batched calls the batch coordinate is 0 (grid_dim.z = 1).
Args:
- a_tma_op (TMATensorTile[a_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]): 3D TMA descriptor for A matrix.
- b_tma_op (TMATensorTile[b_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]): 3D TMA descriptor for B matrix.
- tiles (ProducerTiles[tiles_origin, StandardTilePayload[a_type, b_type, IndexList(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BM, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK, __list_literal__=NoneType(None)), IndexList(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BN, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK, __list_literal__=NoneType(None)), BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.num_pipeline_stages], BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.num_group_pipeline_stages, config.k_group_size]): ProducerTiles context with encapsulated tile access.
- peer_cta_coord (Tuple[Int, Int, Int]): (rank_n, rank_m, peer_m_rank) for peer CTA slicing.
- work_tile_coord (Tuple[Int, Int, Int]): (m, n, batch) coordinates.
- a_multicast_mask (UInt16): Multicast mask for A tiles.
- b_multicast_mask (UInt16): Multicast mask for B tiles.
- iter_idx (UInt32): K iteration index (base index for k_group).
- elect_one_cta (Bool): True if this CTA should call expect_bytes.
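The byte count that the elected CTA passes to expect_bytes can be estimated with simple arithmetic. The sketch below is a single-CTA view that assumes a_tile_elems = BM * BK and b_tile_elems = BN * BK (by analogy with the a_expected_bytes member, a_tile_elems * size_of[a_type]()) and that each pipeline stage holds k_group_size A/B tile pairs; the function name and the 2-SM accounting it leaves out are assumptions, not the kernel's code.

```python
# Per-stage transaction-byte budget for the input barrier (single-CTA view).
# Assumptions: a_tile_elems = BM * BK, b_tile_elems = BN * BK, and each
# stage holds k_group_size A/B tile pairs. cta_group=2 accounting is not
# modeled.


def stage_expected_bytes(bm, bn, bk, a_elem_bytes, b_elem_bytes, k_group_size):
    a_expected = bm * bk * a_elem_bytes  # analogous to a_expected_bytes
    b_expected = bn * bk * b_elem_bytes  # hypothetical B-side counterpart
    return k_group_size * (a_expected + b_expected)


# bf16 inputs (2 bytes per element), BM = BN = 128, BK = 64.
print(stage_expected_bytes(128, 128, 64, 2, 2, k_group_size=1))  # 32768
print(stage_expected_bytes(128, 128, 64, 2, 2, k_group_size=2))  # 65536
```

Because only the elected CTA arms the barrier with this total, the barrier fires once every multicast copy destined for the stage has landed, regardless of which CTA issued it.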
prefetch_a_tiles
static prefetch_a_tiles[tiles_origin: MutOrigin, //](a_tma_op: TMATensorTile[a_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], tiles: ProducerTiles[tiles_origin, StandardTilePayload[a_type, b_type, IndexList(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BM, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK, __list_literal__=NoneType(None)), IndexList(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BN, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK, __list_literal__=NoneType(None)), BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.num_pipeline_stages], BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.num_group_pipeline_stages, config.k_group_size], peer_cta_coord: Tuple[Int, Int, Int], work_tile_coord: Tuple[Int, Int, Int], a_multicast_mask: UInt16, iter_idx: UInt32, elect_one_cta: Bool)
Load A tiles only; set full expected bytes (A+B) on the barrier.
Called before wait_on_dependent_grids() to prefetch the static weight matrix (kernel-A in swapAB mode). The barrier will not fire until the matching complete_b_tiles() call delivers the remaining B bytes.
Args:
- a_tma_op (TMATensorTile[a_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]): 3D TMA descriptor for A matrix.
- tiles (ProducerTiles[tiles_origin, StandardTilePayload[a_type, b_type, IndexList(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BM, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK, __list_literal__=NoneType(None)), IndexList(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BN, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK, __list_literal__=NoneType(None)), BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.num_pipeline_stages], BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.num_group_pipeline_stages, config.k_group_size]): ProducerTiles context with encapsulated tile access.
- peer_cta_coord (Tuple[Int, Int, Int]): (rank_n, rank_m, peer_m_rank) for peer CTA slicing.
- work_tile_coord (Tuple[Int, Int, Int]): (m, n, batch) coordinates.
- a_multicast_mask (UInt16): Multicast mask for A tiles.
- iter_idx (UInt32): K iteration index (base index for k_group).
- elect_one_cta (Bool): True if this CTA should call expect_bytes.
complete_b_tiles
static complete_b_tiles(b_tma_op: TMATensorTile[b_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], stage: UInt32, barrier: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], payload: StandardTilePayload[a_type, b_type, IndexList(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BM, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK, __list_literal__=NoneType(None)), IndexList(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BN, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK, __list_literal__=NoneType(None)), BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.num_pipeline_stages], peer_cta_coord: Tuple[Int, Int, Int], work_tile_coord: Tuple[Int, Int, Int], b_multicast_mask: UInt16, iter_idx: UInt32)
Load B tiles into a previously prefetched stage.
Delivers the remaining B bytes so that the stage barrier fires and the consumer can proceed. Pair with prefetch_a_tiles().
Args:
- b_tma_op (TMATensorTile[b_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]): 3D TMA descriptor for B matrix.
- stage (UInt32): Stage index saved from the prefetch phase.
- barrier (UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]): Barrier pointer saved from the prefetch phase.
- payload (StandardTilePayload[a_type, b_type, IndexList(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BM, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK, __list_literal__=NoneType(None)), IndexList(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BN, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK, __list_literal__=NoneType(None)), BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.num_pipeline_stages]): Tile payload from the pipeline (gives smem pointers).
- peer_cta_coord (Tuple[Int, Int, Int]): (rank_n, rank_m, peer_m_rank) for peer CTA slicing.
- work_tile_coord (Tuple[Int, Int, Int]): (m, n, batch) coordinates.
- b_multicast_mask (UInt16): Multicast mask for B tiles.
- iter_idx (UInt32): K iteration index (base index for k_group).
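The prefetch_a_tiles / complete_b_tiles pairing relies on the barrier being armed with the full A+B byte count up front, so it cannot fire after only the A bytes arrive. The toy Python model below illustrates that handshake; TxBarrier and the tile sizes are hypothetical, and real mbarriers additionally track a thread-arrival count and a phase bit, which are omitted here.

```python
class TxBarrier:
    """Toy model of a shared-memory barrier that tracks outstanding TMA bytes.

    Simplification: real mbarriers also track a thread-arrival count and a
    phase bit; only the transaction-byte part is modeled here.
    """

    def __init__(self):
        self.pending_bytes = 0
        self.armed = False

    def expect_bytes(self, n):
        self.pending_bytes += n
        self.armed = True

    def complete_tx(self, n):
        # Stands in for a TMA copy landing in shared memory.
        self.pending_bytes -= n

    def fired(self):
        return self.armed and self.pending_bytes == 0


A_BYTES = 16384  # hypothetical A tile size for one stage
B_BYTES = 16384  # hypothetical B tile size for one stage

bar = TxBarrier()
# prefetch_a_tiles(): arm with the full A+B budget, but only A is loaded.
bar.expect_bytes(A_BYTES + B_BYTES)
bar.complete_tx(A_BYTES)
print(bar.fired())  # False: the MMA warp keeps waiting on this stage

# ... wait_on_dependent_grids() returns ...
# complete_b_tiles(): deliver B into the same stage via the saved barrier.
bar.complete_tx(B_BYTES)
print(bar.fired())  # True
```

This is why complete_b_tiles takes the stage index and barrier pointer saved from the prefetch phase: the B bytes must be credited to the same barrier that was armed with A+B.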
load_input_tiles_splitk
static load_input_tiles_splitk[a_tma_origin: ImmutOrigin, b_tma_origin: ImmutOrigin, tiles_origin: MutOrigin, //](a_loader: TileLoader[a_tma_origin, a_type, Layout[*?, *?], Layout[*?, *?], cta_group=BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group], b_loader: TileLoader[b_tma_origin, b_type, Layout[*?, *?], Layout[*?, *?], cta_group=BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group], tiles: ProducerTiles[tiles_origin, StandardTilePayload[a_type, b_type, IndexList(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BM, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK, __list_literal__=NoneType(None)), IndexList(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BN, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK, __list_literal__=NoneType(None)), BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.num_pipeline_stages], BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.num_group_pipeline_stages, config.k_group_size], iter_idx: UInt32, work_m_coord: Int, work_n_coord: Int, peer_cta_coord: Tuple[Int, Int, Int], elect_one_cta: Bool)
Load k_group_size A and B tiles using 2D TMA (for split-K only).
Orchestrates the tile loading operation including:
- expect_bytes signaling
- k-group iteration
- Peer CTA slicing for 2-SM MMA
Args:
- a_loader (TileLoader[a_tma_origin, a_type, Layout[*?, *?], Layout[*?, *?], cta_group=BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group]): TileLoader for A matrix (2D).
- b_loader (TileLoader[b_tma_origin, b_type, Layout[*?, *?], Layout[*?, *?], cta_group=BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group]): TileLoader for B matrix (2D).
- tiles (ProducerTiles[tiles_origin, StandardTilePayload[a_type, b_type, IndexList(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BM, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK, __list_literal__=NoneType(None)), IndexList(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BN, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK, __list_literal__=NoneType(None)), BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.num_pipeline_stages], BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.num_group_pipeline_stages, config.k_group_size]): ProducerTiles context with encapsulated tile access.
- iter_idx (UInt32): K iteration index (base index).
- work_m_coord (Int): M coordinate of the output tile.
- work_n_coord (Int): N coordinate of the output tile.
- peer_cta_coord (Tuple[Int, Int, Int]): Peer CTA coordinates (rank_n, rank_m, peer_m_rank).
- elect_one_cta (Bool): True if this CTA should call expect_bytes.
run
static run(a_tma_op: TMATensorTile[a_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], b_tma_op: TMATensorTile[b_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], c_tma_op: TMATensorTile[c_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], cluster_dim: StaticTuple[Int32, 3], mnk: StaticTuple[UInt32, 3], workspace: Span[UInt64, MutAnyOrigin])
Main kernel entry point for SM100 matrix multiplication.
Always uses 3D TMA descriptors. For non-batched inputs, batch = 1 and the batch coordinate is 0, because k_start = block_idx.z = 0 when grid_dim.z = 1. For batched inputs, grid_dim.z = batch_size and the batch coordinate is likewise taken from k_start.
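The batch handling above can be made concrete by counting logical output tiles. In this Python sketch the helper names are hypothetical, coordinates are tile indices rather than element offsets (an assumption about work_tile_coord), and the z extent plays the role of grid_dim.z: 1 for non-batched inputs, batch_size otherwise. With CLC dynamic scheduling the assignment of CTAs to tiles is decided at run time, and the launch grid may be rounded to cluster multiples; neither is modeled here.

```python
# Logical output-tile grid for an M x N result, with grid_dim.z = batch.
# Hypothetical helpers; CLC scheduling and cluster rounding are not modeled.


def output_tile_grid(m, n, bm, bn, batch=1):
    """Return (tiles_m, tiles_n, grid_z) for an M x N output per batch."""
    tiles_m = (m + bm - 1) // bm  # ceil(M / BM)
    tiles_n = (n + bn - 1) // bn  # ceil(N / BN)
    return tiles_m, tiles_n, batch


def work_tile_coords(m, n, bm, bn, batch=1):
    """Enumerate (m, n, batch) tile coordinates; batch plays block_idx.z."""
    tiles_m, tiles_n, grid_z = output_tile_grid(m, n, bm, bn, batch)
    return [(tm, tn, b) for b in range(grid_z)
            for tm in range(tiles_m) for tn in range(tiles_n)]


print(output_tile_grid(1000, 512, 128, 256))          # (8, 2, 1)
print(len(work_tile_coords(1000, 512, 128, 256, 4)))  # 64
```

For a non-batched call every coordinate carries batch 0, which matches the rule that the third 3D TMA coordinate is 0 when grid_dim.z = 1.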
run_splitk
static run_splitk[reduction_layout: TensorLayout](a_tma_op: TMATensorTile[a_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], b_tma_op: TMATensorTile[b_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], c_tma_op: TMATensorTile[c_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], reduction_tensor: TileTensor[MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type, reduction_layout, MutAnyOrigin], lock_ptr: UnsafePointer[UInt8, MutAnyOrigin], cluster_dim: StaticTuple[Int32, 3], mnk: StaticTuple[UInt32, 3], workspace: Span[UInt64, MutAnyOrigin])
Split-K kernel entry point for better parallelism on small problems.
Split-K divides the K dimension across multiple CTAs, with each CTA computing a partial result that is then reduced.
Args:
- a_tma_op (TMATensorTile[a_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]): TMA descriptor for matrix A.
- b_tma_op (TMATensorTile[b_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]): TMA descriptor for matrix B.
- c_tma_op (TMATensorTile[c_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]): TMA descriptor for matrix C.
- reduction_tensor (TileTensor[MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type, reduction_layout, MutAnyOrigin]): Workspace for partial results from each split.
- lock_ptr (UnsafePointer[UInt8, MutAnyOrigin]): Synchronization locks for reduction coordination.
- cluster_dim (StaticTuple[Int32, 3]): Cluster dimensions.
- mnk (StaticTuple[UInt32, 3]): Problem dimensions (M, N, K).
- workspace (Span[UInt64, MutAnyOrigin]): Workspace buffer for profiling/scheduling.
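The split-K idea described for run_splitk can be checked numerically with a plain-Python sketch: K tiles are divided into contiguous ranges, each split computes a partial dot product over its range, and summing the partials reproduces the full result. The even-split policy (earlier splits take the remainder) and both helper names are assumptions for illustration; the kernel's actual split assignment and its lock-based reduction through reduction_tensor and lock_ptr may differ.

```python
# Split-K on a single output element: partition K tiles across splits,
# compute per-split partial sums, then reduce. Hypothetical helpers.


def split_k_ranges(num_k_tiles, num_splits):
    """Divide K tiles as evenly as possible; earlier splits get the remainder."""
    base, rem = divmod(num_k_tiles, num_splits)
    ranges, start = [], 0
    for s in range(num_splits):
        size = base + (1 if s < rem else 0)
        ranges.append((start, start + size))
        start += size
    return ranges


def split_k_dot(a_row, b_col, tile_k, num_splits):
    """Each split handles a contiguous K range; the partials are summed."""
    num_k_tiles = -(-len(a_row) // tile_k)  # ceil(K / tile_k)
    partials = []
    for k0, k1 in split_k_ranges(num_k_tiles, num_splits):
        lo, hi = k0 * tile_k, min(k1 * tile_k, len(a_row))
        partials.append(sum(x * y for x, y in zip(a_row[lo:hi], b_col[lo:hi])))
    return sum(partials)  # stands in for the reduction step


print(split_k_ranges(10, 3))                       # [(0, 4), (4, 7), (7, 10)]
print(split_k_dot(list(range(10)), [1] * 10, 3, 2))  # 45
```

Splitting K raises the number of independent CTAs for small M x N problems at the cost of writing and reducing partial accumulators, which is the trade-off the reduction workspace and locks exist to manage.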