Skip to main content

Mojo struct

BlackwellMatmulSM100Kernel

struct BlackwellMatmulSM100Kernel[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], cluster_shape: StaticTuple[Int32, 3] = StaticTuple(SIMD(1)), elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, pdl_level: PDLLevel = PDLLevel(), max_profiled_tiles_per_SM: UInt32 = 0]

Blackwell SM100 GEMM kernel with warp specialization.

This struct unifies all parameters and derived types for the SM100 matmul kernel, providing:

  • Compile-time parameter validation
  • Centralized derived type computation
  • Factory methods for kernel components
  • Multiple kernel entry points (standard, split-k)

The SM100 kernel uses:

  • Tensor Memory (TMEM) for MMA accumulators
  • Cluster Launch Control (CLC) for dynamic tile scheduling
  • Warp specialization: Scheduler, TMA Load, MMA, Epilogue warps
  • Software pipelining for overlapping compute and memory operations

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

a_expected_bytes

comptime a_expected_bytes = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.Layouts.a_tile_elems * size_of[a_type]())

a_smem_layout

comptime a_smem_layout = Layout(Coord(Coord(Idx[8](), Idx[(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BM // 8)]()), Coord(Idx[(config.a_swizzle.bytes() // size_of[a_type]())](), Idx[(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK // (config.a_swizzle.bytes() // size_of[a_type]()))]())), Coord(Coord(Idx[(config.a_swizzle.bytes() // size_of[a_type]())](), Idx[(8 * (config.a_swizzle.bytes() // size_of[a_type]()))]()), Coord(Idx[1](), Idx[0 if (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK == (config.a_swizzle.bytes() // size_of[a_type]())) else (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BM * (config.a_swizzle.bytes() // size_of[a_type]()))]()))).to_layout()

a_swizzle_elems

comptime a_swizzle_elems = (config.a_swizzle.bytes() // size_of[a_type]())

a_tile_dim0

comptime a_tile_dim0 = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BM // BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_N)

a_tma_load_size

comptime a_tma_load_size = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].a_tile_dim0 * BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].a_swizzle_elems)

a_tma_rows

comptime a_tma_rows = BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].a_tile_dim0

accum_layout

comptime accum_layout = Layout.row_major(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_M, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_N)

accum_pipeline_consumer_arv_count

comptime accum_pipeline_consumer_arv_count = compute_accum_barrier_counts[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group]()[1]

accum_pipeline_producer_arv_count

comptime accum_pipeline_producer_arv_count = compute_accum_barrier_counts[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group]()[0]

accum_type

comptime accum_type = MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type

AccumTensor

comptime AccumTensor = TmemTensor[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].accum_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].accum_layout, cta_group=BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group]

ADescLayout

comptime ADescLayout = Layout[ComptimeInt[1], ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].a_tile_dim0], ComptimeInt[(config.a_swizzle.bytes() // size_of[a_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]

ADescLayout_splitk

comptime ADescLayout_splitk = Layout[ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].a_tile_dim0], ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].a_swizzle_elems], ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].a_swizzle_elems], ComptimeInt[1]]

ATileLayout

comptime ATileLayout = Layout[ComptimeInt[1], ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].a_tile_dim0], ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK], ComptimeInt[(ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]

ATileLayout_splitk

comptime ATileLayout_splitk = Layout[ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].a_tile_dim0], ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK], ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK], ComptimeInt[1]]

ATmaOp

comptime ATmaOp = TMATensorTile[a_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ATileLayout.rank, _to_index_list[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ATileLayout](), _to_index_list[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ATileLayout.rank, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ADescLayout]()]

ATmaOp_splitk

comptime ATmaOp_splitk = TMATensorTile[a_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ATileLayout_splitk.rank, _to_index_list[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ATileLayout_splitk](), _to_index_list[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ATileLayout_splitk.rank, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ADescLayout_splitk]()]

b_expected_bytes

comptime b_expected_bytes = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.Layouts.b_tile_elems * size_of[b_type]())

b_smem_layout

comptime b_smem_layout = Layout(Coord(Coord(Idx[8](), Idx[(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BN // 8)]()), Coord(Idx[(config.b_swizzle.bytes() // size_of[b_type]())](), Idx[(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK // (config.b_swizzle.bytes() // size_of[b_type]()))]())), Coord(Coord(Idx[(config.b_swizzle.bytes() // size_of[b_type]())](), Idx[(8 * (config.b_swizzle.bytes() // size_of[b_type]()))]()), Coord(Idx[1](), Idx[0 if (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK == (config.b_swizzle.bytes() // size_of[b_type]())) else (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BN * (config.b_swizzle.bytes() // size_of[b_type]()))]()))).to_layout() if transpose_b else Layout(Coord(Coord(Idx[8](), Idx[(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK // 8)]()), Coord(Idx[(config.b_swizzle.bytes() // size_of[b_type]())](), Idx[(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BN // (config.b_swizzle.bytes() // size_of[b_type]()))]())), Coord(Coord(Idx[(config.b_swizzle.bytes() // size_of[b_type]())](), Idx[(8 * (config.b_swizzle.bytes() // size_of[b_type]()))]()), Coord(Idx[1](), Idx[0 if (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BN == (config.b_swizzle.bytes() // size_of[b_type]())) else (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK * (config.b_swizzle.bytes() // size_of[b_type]()))]()))).transpose().to_layout()

b_swizzle_elems

comptime b_swizzle_elems = (config.b_swizzle.bytes() // size_of[b_type]())

b_tile_dim0

comptime b_tile_dim0 = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BN // (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_M // BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group))

b_tma_load_size

comptime b_tma_load_size = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].b_tile_dim0 * BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].b_swizzle_elems)

b_tma_rows

comptime b_tma_rows = BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].b_tile_dim0

BDescLayout

comptime BDescLayout = Layout[ComptimeInt[1], ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].b_tile_dim0], ComptimeInt[(config.b_swizzle.bytes() // size_of[b_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]

BDescLayout_splitk

comptime BDescLayout_splitk = Layout[ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].b_tile_dim0], ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].b_swizzle_elems], ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].b_swizzle_elems], ComptimeInt[1]]

BK

comptime BK = config.block_tile_shape[2]

BM

comptime BM = config.block_tile_shape[0]

BN

comptime BN = config.block_tile_shape[1]

BTileLayout

comptime BTileLayout = Layout[ComptimeInt[1], ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].b_tile_dim0], ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK], ComptimeInt[(ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]

BTileLayout_splitk

comptime BTileLayout_splitk = Layout[ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].b_tile_dim0], ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK], ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK], ComptimeInt[1]]

BTmaOp

comptime BTmaOp = TMATensorTile[b_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BTileLayout.rank, _to_index_list[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BTileLayout](), _to_index_list[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BTileLayout.rank, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BDescLayout]()]

BTmaOp_splitk

comptime BTmaOp_splitk = TMATensorTile[b_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BTileLayout_splitk.rank, _to_index_list[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BTileLayout_splitk](), _to_index_list[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BTileLayout_splitk.rank, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BDescLayout_splitk]()]

c_swizzle_elems

comptime c_swizzle_elems = (config.c_swizzle.bytes() // size_of[c_type]())

c_tile_dim0

comptime c_tile_dim0 = BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].OutputM if ((BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_M == 256) or (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group == 1) or config.AB_swapped) else 64

c_tile_dim1

comptime c_tile_dim1 = BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].c_swizzle_elems if config.AB_swapped else BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].OutputN

CDescLayout

comptime CDescLayout = Layout[ComptimeInt[1], ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].c_tile_dim0], ComptimeInt[(config.c_swizzle.bytes() // size_of[c_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]

CDescLayout_splitk

comptime CDescLayout_splitk = Layout[ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].c_tile_dim0], ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].c_swizzle_elems], ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].c_swizzle_elems], ComptimeInt[1]]

clc_consumer_arv_count

comptime clc_consumer_arv_count = compute_clc_barrier_counts[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SCHEDULER_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].TMA_LOAD_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_SIZE, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group]()[1]

clc_producer_arv_count

comptime clc_producer_arv_count = compute_clc_barrier_counts[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SCHEDULER_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].TMA_LOAD_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_SIZE, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group]()[0]

clc_throttle_consumer_arv_count

comptime clc_throttle_consumer_arv_count = compute_clc_barrier_counts[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SCHEDULER_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].TMA_LOAD_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_SIZE, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group]()[3]

clc_throttle_producer_arv_count

comptime clc_throttle_producer_arv_count = compute_clc_barrier_counts[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SCHEDULER_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].TMA_LOAD_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_SIZE, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group]()[2]

CLUSTER_M

comptime CLUSTER_M = config.cluster_shape[0]

CLUSTER_N

comptime CLUSTER_N = config.cluster_shape[1]

CLUSTER_SIZE

comptime CLUSTER_SIZE = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_M * BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_N)

Context

comptime Context = KernelContext[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].num_clc_pipeline_stages, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_M, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_N]

cta_group

comptime cta_group = config.cta_group

CTileLayout

comptime CTileLayout = Layout[ComptimeInt[1], ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].c_tile_dim0], ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].c_tile_dim1], ComptimeInt[(ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]

CTileLayout_splitk

comptime CTileLayout_splitk = Layout[ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].c_tile_dim0], ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].c_tile_dim1], ComptimeInt[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].c_tile_dim1], ComptimeInt[1]]

CTmaOp

comptime CTmaOp = TMATensorTile[c_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CTileLayout.rank, _to_index_list[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CTileLayout](), _to_index_list[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CTileLayout.rank, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CDescLayout]()]

CTmaOp_splitk

comptime CTmaOp_splitk = TMATensorTile[c_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CTileLayout_splitk.rank, _to_index_list[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CTileLayout_splitk](), _to_index_list[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CTileLayout_splitk.rank, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CDescLayout_splitk]()]

EPILOGUE_THREADS

comptime EPILOGUE_THREADS = (4 * WARP_SIZE)

EpilogueCtx

comptime EpilogueCtx = EpilogueWarpContext[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].opc, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS]

input_expected_bytes

comptime input_expected_bytes = ((BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group * (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].a_expected_bytes + BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].b_expected_bytes)) * config.k_group_size)

InputTilePipeline

comptime InputTilePipeline = InputTilePipeline[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].TilePayload, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.num_group_pipeline_stages, config.k_group_size]

MMA_K

comptime MMA_K = config.mma_shape[2]

MMA_M

comptime MMA_M = config.mma_shape[0]

MMA_N

comptime MMA_N = config.mma_shape[1]

MMA_THREADS

comptime MMA_THREADS = WARP_SIZE

MmaCtx

comptime MmaCtx = MmaWarpContext[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].opc, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS]

MmaEpilogueSync

comptime MmaEpilogueSync = WarpGroupBarrier[(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS + BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS), 1]

MmaOp

comptime MmaOp = MmaOpSM100_SS[c_type, a_type, b_type, config.block_tile_shape, config.mma_shape, accum_type=BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].accum_type, cta_group=BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b]

num_accum_pipeline_stages

comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages

num_clc_pipeline_stages

comptime num_clc_pipeline_stages = config.num_clc_pipeline_stages

num_group_pipeline_stages

comptime num_group_pipeline_stages = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].num_pipeline_stages // config.k_group_size)

num_k_mmas

comptime num_k_mmas = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK // BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_K)

num_m_mmas

comptime num_m_mmas = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BM // (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_M // BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group))

num_n_mmas

comptime num_n_mmas = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BN // (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_N // BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group))

num_output_stages

comptime num_output_stages = config.num_output_stages

num_output_warps

comptime num_output_warps = 4

num_pipeline_stages

comptime num_pipeline_stages = config.num_pipeline_stages

NUM_THREADS

comptime NUM_THREADS = (((BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SCHEDULER_THREADS + BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].TMA_LOAD_THREADS) + BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS) + BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS)

NUM_TMEM_COLS

comptime NUM_TMEM_COLS = 512

opc

comptime opc = OutputPipelineConfig(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].num_accum_pipeline_stages, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].stage_stride_cols, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group)

OutputM

comptime OutputM = config.output_tile_shape[0]

OutputN

comptime OutputN = config.output_tile_shape[1]

OutputPipeline

comptime OutputPipeline = OutputTilePipeline[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].opc]

register_based_epilogue

comptime register_based_epilogue = config.register_based_epilogue

Scheduler

comptime Scheduler = TileScheduler[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].num_clc_pipeline_stages, Index[Int, Int, Int, dtype=DType.uint32](config.cluster_shape[0], config.cluster_shape[1], config.cluster_shape[2]), config.raster_order, config.block_swizzle_size]

SCHEDULER_THREADS

comptime SCHEDULER_THREADS = WARP_SIZE

SmemType

comptime SmemType = B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config]

stage_stride_cols

comptime stage_stride_cols = (512 // BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].num_accum_pipeline_stages)

TilePayload

comptime TilePayload = StandardTilePayload[a_type, b_type, IndexList(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BM, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK, __list_literal__=Tuple()), IndexList(BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BN, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK, __list_literal__=Tuple()), BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.num_pipeline_stages]

TileWriterType

comptime TileWriterType = TileWriter[a_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].accum_type, config.block_tile_shape, config.mma_shape, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].opc, config.c_swizzle, config.AB_swapped, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.OutputM, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.OutputN, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.num_output_stages, 4, elementwise_lambda_fn, elementwise_compute_lambda_fn, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].register_based_epilogue, True]

TileWriterType_splitk

comptime TileWriterType_splitk = TileWriter[a_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].accum_type, config.block_tile_shape, config.mma_shape, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].opc, config.c_swizzle, config.AB_swapped, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.OutputM, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.OutputN, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.num_output_stages, 4, elementwise_lambda_fn, elementwise_compute_lambda_fn, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].register_based_epilogue]

TMA_LOAD_THREADS

comptime TMA_LOAD_THREADS = WARP_SIZE

Tmem

comptime Tmem = TmemAllocation[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].opc.cta_group]

TmemDealloc

comptime TmemDealloc = TmemDeallocBarrier[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].opc.cta_group]

Methods

validate_constraints

static validate_constraints()

Validate parameter constraints at compile time.

init_barriers

static init_barriers(ctx: KernelContext[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].num_clc_pipeline_stages, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_M, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_N], input_barriers: SMemArray[SharedMemBarrier, (B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].num_group_pipeline_stages * 2)], accum_barriers: SMemArray[SharedMemBarrier, (B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].num_accum_pipeline_stages * 2)], clc_throttle: SMemArray[SharedMemBarrier, (B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].num_clc_pipeline_stages * 2)], clc_full: SMemArray[SharedMemBarrier, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].num_clc_pipeline_stages], clc_empty: SMemArray[SharedMemBarrier, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].num_clc_pipeline_stages], tmem_dealloc: SMemArray[SharedMemBarrier, 1])

Initialize barriers. TMA descriptor prefetch is done by each kernel entry point before calling this method.

mma

static mma[tiles_origin: MutOrigin, //](tmem_stage: TmemStage[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].opc], tiles: ConsumerTiles[tiles_origin, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].TilePayload, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.num_group_pipeline_stages, config.k_group_size], mma_op: MmaOpSM100_SS[mma_op.c_type, mma_op.a_type, mma_op.b_type, mma_op.block_tile_shape, mma_op.mma_shape, accum_type=mma_op.accum_type, cta_group=mma_op.cta_group, cluster_shape=mma_op.cluster_shape, a_swizzle=mma_op.a_swizzle, b_swizzle=mma_op.b_swizzle, transpose_b=mma_op.transpose_b], elect_one_warp: Bool, iter_idx: UInt32, k_start: UInt32)

Execute MMA operations for one pipeline stage.

This is the core MMA function designed to be called within a consumer stage context:

with consumer.acquire() as tiles:
    Self.mma(stage.tmem, tiles, mma_op, ...)

Args:

  • tmem_stage (TmemStage): TMEM stage for accumulators.
  • tiles (ConsumerTiles): ConsumerTiles context with encapsulated tile access.
  • mma_op (MmaOpSM100_SS): The MMA operation instance.
  • elect_one_warp (Bool): Whether this warp should execute.
  • iter_idx (UInt32): K iteration index.
  • k_start (UInt32): Starting K iteration (for init_c determination).

load_input_tiles

static load_input_tiles[tiles_origin: MutOrigin, //](a_tma_op: TMATensorTile[a_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ATileLayout.rank, _to_index_list[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ATileLayout](), _to_index_list[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ATileLayout.rank, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ADescLayout]()], b_tma_op: TMATensorTile[b_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BTileLayout.rank, _to_index_list[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BTileLayout](), _to_index_list[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BTileLayout.rank, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BDescLayout]()], tiles: ProducerTiles[tiles_origin, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].TilePayload, 
BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.num_group_pipeline_stages, config.k_group_size], peer_cta_coord: Tuple[Int, Int, Int], work_tile_coord: Tuple[Int, Int, Int], a_multicast_mask: UInt16, b_multicast_mask: UInt16, iter_idx: UInt32, elect_one_cta: Bool)

Load A and B tiles using 3D TMA.

Uses async_multicast_load_3d, with the batch coordinate taken from work_tile_coord[2]. For non-batched calls, the batch coordinate is 0 (grid_dim.z = 1).

Args:

  • a_tma_op (TMATensorTile): 3D TMA descriptor for A matrix.
  • b_tma_op (TMATensorTile): 3D TMA descriptor for B matrix.
  • tiles (ProducerTiles): ProducerTiles context with encapsulated tile access.
  • peer_cta_coord (Tuple): (rank_n, rank_m, peer_m_rank) for peer CTA slicing.
  • work_tile_coord (Tuple): (m, n, batch) coordinates.
  • a_multicast_mask (UInt16): Multicast mask for A tiles.
  • b_multicast_mask (UInt16): Multicast mask for B tiles.
  • iter_idx (UInt32): K iteration index (base index for k_group).
  • elect_one_cta (Bool): True if this CTA should call expect_bytes.

load_input_tiles_splitk

static load_input_tiles_splitk[a_tma_origin: ImmutOrigin, b_tma_origin: ImmutOrigin, tiles_origin: MutOrigin, //](a_loader: TileLoader[a_tma_origin, a_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ATileLayout_splitk, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ADescLayout_splitk, cta_group=BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group], b_loader: TileLoader[b_tma_origin, b_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BTileLayout_splitk, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BDescLayout_splitk, cta_group=BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group], tiles: ProducerTiles[tiles_origin, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].TilePayload, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.num_group_pipeline_stages, config.k_group_size], iter_idx: UInt32, work_m_coord: Int, work_n_coord: Int, peer_cta_coord: Tuple[Int, Int, Int], elect_one_cta: Bool)

Load k_group_size A and B tiles using 2D TMA (for split-K only).

Orchestrates the tile loading operation including:

  • expect_bytes signaling
  • k-group iteration
  • Peer CTA slicing for 2-SM MMA

Args:

  • a_loader (TileLoader): TileLoader for A matrix (2D).
  • b_loader (TileLoader): TileLoader for B matrix (2D).
  • tiles (ProducerTiles): ProducerTiles context with encapsulated tile access.
  • iter_idx (UInt32): K iteration index (base index).
  • work_m_coord (Int): M coordinate of the output tile.
  • work_n_coord (Int): N coordinate of the output tile.
  • peer_cta_coord (Tuple): Peer CTA coordinates (rank_n, rank_m, peer_m_rank).
  • elect_one_cta (Bool): True if this CTA should call expect_bytes.

run

static run(a_tma_op: TMATensorTile[a_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ATileLayout.rank, _to_index_list[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ATileLayout](), _to_index_list[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ATileLayout.rank, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ADescLayout]()], b_tma_op: TMATensorTile[b_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BTileLayout.rank, _to_index_list[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BTileLayout](), _to_index_list[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BTileLayout.rank, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BDescLayout]()], c_tma_op: TMATensorTile[c_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CTileLayout.rank, _to_index_list[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, 
transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CTileLayout](), _to_index_list[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CTileLayout.rank, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CDescLayout]()], cluster_dim: StaticTuple[Int32, 3], mnk: StaticTuple[UInt32, 3], workspace: Span[UInt64, MutAnyOrigin])

Main kernel entry point for SM100 matrix multiplication.

Always uses 3D TMA descriptors. For non-batched inputs, batch = 1 and batch_coord = 0 (k_start = block_idx.z = 0 because grid_dim.z = 1). For batched inputs, grid_dim.z = batch_size and batch_coord is taken from k_start.

run_splitk

static run_splitk[reduction_layout: TensorLayout](a_tma_op: TMATensorTile[a_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ATileLayout_splitk.rank, _to_index_list[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ATileLayout_splitk](), _to_index_list[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ATileLayout_splitk.rank, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ADescLayout_splitk]()], b_tma_op: TMATensorTile[b_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BTileLayout_splitk.rank, _to_index_list[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BTileLayout_splitk](), _to_index_list[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BTileLayout_splitk.rank, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BDescLayout_splitk]()], c_tma_op: TMATensorTile[c_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, 
max_profiled_tiles_per_SM].CTileLayout_splitk.rank, _to_index_list[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CTileLayout_splitk](), _to_index_list[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CTileLayout_splitk.rank, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, transpose_b, config, cluster_shape, elementwise_lambda_fn, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CDescLayout_splitk]()], reduction_tensor: TileTensor[MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type, reduction_layout, MutAnyOrigin], lock_ptr: UnsafePointer[UInt8, MutAnyOrigin], cluster_dim: StaticTuple[Int32, 3], mnk: StaticTuple[UInt32, 3], workspace: Span[UInt64, MutAnyOrigin])

Split-K kernel entry point for better parallelism on small problems.

Split-K divides the K dimension across multiple CTAs, with each CTA computing a partial result that is then reduced.

Args:

  • a_tma_op (TMATensorTile): TMA descriptor for matrix A.
  • b_tma_op (TMATensorTile): TMA descriptor for matrix B.
  • c_tma_op (TMATensorTile): TMA descriptor for matrix C.
  • reduction_tensor (TileTensor): Workspace for partial results from each split.
  • lock_ptr (UnsafePointer): Synchronization locks for reduction coordination.
  • cluster_dim (StaticTuple): Cluster dimensions.
  • mnk (StaticTuple): Problem dimensions (M, N, K).
  • workspace (Span): Workspace buffer for profiling/scheduling.

Was this page helpful?