Mojo struct
Grouped1D1DMatmulKernel
struct Grouped1D1DMatmulKernel[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, c_device_layout: TensorLayout, transpose_b: Bool, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b], static_N: Int, cluster_shape: StaticTuple[Int32, 3] = StaticTuple(Int32(1)), elementwise_compute_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None]
Grouped 1D-1D block-scaled matmul kernel.
Uses 3-warp specialization (Load, MMA, Epilogue) with grid-constant TMAs. Work distribution via GroupedWorkIterator1D1D using offset-based addressing.
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
a_expected_bytes
comptime a_expected_bytes = ((Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].BM * Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].BK) * size_of[a_type]())
a_swizzle_elems
comptime a_swizzle_elems = (config.a_swizzle.bytes() // size_of[a_type]())
a_tile_dim0
comptime a_tile_dim0 = compute_tma_tile_dims[Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].BM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].BN, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].MMA_M, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].OutputM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_M, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_N, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].cta_group, config.AB_swapped]()[0]
a_tma_load_size
comptime a_tma_load_size = (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0 * Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].a_swizzle_elems)
a_tma_rows
comptime a_tma_rows = Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0
accum_pipeline_consumer_arv_count
comptime accum_pipeline_consumer_arv_count = compute_accum_barrier_counts[128, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].cta_group]()[1]
accum_pipeline_producer_arv_count
comptime accum_pipeline_producer_arv_count = compute_accum_barrier_counts[128, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].cta_group]()[0]
accum_type
comptime accum_type = DType.float32
ADescLayout
comptime ADescLayout = Layout[*?, *?]
AScaleOffsetsTile
comptime AScaleOffsetsTile = TileTensor[DType.uint32, Layout[*?, *?], MutAnyOrigin]
ATileLayout
comptime ATileLayout = Layout[*?, *?]
ATmaOp
comptime ATmaOp = TMATensorTile[a_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]
b_expected_bytes
comptime b_expected_bytes = ((Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].BN * Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].BK) * size_of[b_type]())
b_swizzle_elems
comptime b_swizzle_elems = (config.b_swizzle.bytes() // size_of[b_type]())
b_tile_dim0
comptime b_tile_dim0 = compute_tma_tile_dims[Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].BM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].BN, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].MMA_M, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].OutputM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_M, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_N, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].cta_group, config.AB_swapped]()[1]
b_tma_load_size
comptime b_tma_load_size = (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0 * Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].b_swizzle_elems)
b_tma_rows
comptime b_tma_rows = Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0
BDescLayout
comptime BDescLayout = Layout[*?, *?]
BK
comptime BK = config.block_tile_shape[2]
BM
comptime BM = config.block_tile_shape[0]
BN
comptime BN = config.block_tile_shape[1]
BTileLayout
comptime BTileLayout = Layout[*?, *?]
BTmaOp
comptime BTmaOp = TMATensorTile[b_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]
c_desc_dim1
comptime c_desc_dim1 = (config.c_swizzle.bytes() // size_of[c_type]()) if config.AB_swapped else config.output_tile_shape[1] if (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].c_swizzle_elems == 0) else Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].c_swizzle_elems
c_swizzle_elems
comptime c_swizzle_elems = (config.c_swizzle.bytes() // size_of[c_type]())
c_tile_dim0
comptime c_tile_dim0 = compute_tma_tile_dims[Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].BM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].BN, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].MMA_M, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].OutputM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_M, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_N, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].cta_group, config.AB_swapped]()[2]
c_tile_dim1
comptime c_tile_dim1 = Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].c_swizzle_elems if config.AB_swapped else Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].OutputN
CDescLayout
comptime CDescLayout = Layout[*?, *?]
CDeviceTile
comptime CDeviceTile = TileTensor[c_type, c_device_layout, MutAnyOrigin]
CLUSTER_M
comptime CLUSTER_M = config.cluster_shape[0]
CLUSTER_N
comptime CLUSTER_N = config.cluster_shape[1]
CLUSTER_SIZE
comptime CLUSTER_SIZE = (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_M * Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_N)
cta_group
comptime cta_group = config.cta_group
CTileLayout
comptime CTileLayout = Layout[*?, *?]
CTmaOp
comptime CTmaOp = TMATensorTile[c_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]
EpilogueCtx
comptime EpilogueCtx = EpilogueWarpContext[Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].opc, 32, 128]
ExpertIdsTile
comptime ExpertIdsTile = TileTensor[DType.int32, Layout[*?, *?], MutAnyOrigin]
ExpertScalesTile
comptime ExpertScalesTile = TileTensor[DType.float32, Layout[*?, *?], MutAnyOrigin]
input_expected_bytes
comptime input_expected_bytes = ((Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].cta_group * (((Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].a_expected_bytes + Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].b_expected_bytes) + Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].sfa_expected_bytes) + Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].sfb_expected_bytes if (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].MMA_N >= 64) else 0)) * config)
InputTilePipelineType
comptime InputTilePipelineType = InputTilePipeline[BlockScaledTilePayload[a_type, b_type, sfa_dtype, sfb_dtype, IndexList(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BK, __list_literal__=Tuple()), IndexList(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BN, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BK, __list_literal__=Tuple()), IndexList(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFA_DIM0, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFA_DIM1, __list_literal__=Tuple()), IndexList(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFB_DIM0, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFB_DIM1, __list_literal__=Tuple()), Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.num_pipeline_stages], Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, 
sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.num_group_pipeline_stages, config.k_group_size]
MMA_K
comptime MMA_K = config.mma_shape[2]
MMA_M
comptime MMA_M = config.mma_shape[0]
MMA_N
comptime MMA_N = config.mma_shape[1]
MmaCtx
comptime MmaCtx = MmaWarpContext[Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].opc, 32, 128]
MmaEpilogueSync
comptime MmaEpilogueSync = WarpGroupBarrier[160, 1]
MmaOp
comptime MmaOp = MmaOpSM100_BlockScaled_SS[c_type, a_type, b_type, sfa_dtype, sfb_dtype, config.scaling_kind, config.block_tile_shape, config.mma_shape, cta_group=Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b]
MmaSfbSync
comptime MmaSfbSync = WarpGroupBarrier[160, 2]
num_accum_pipeline_stages
comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages
num_group_pipeline_stages
comptime num_group_pipeline_stages = (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].num_pipeline_stages // config.k_group_size)
num_output_stages
comptime num_output_stages = config.num_output_stages
num_output_warps
comptime num_output_warps = 4
num_pipeline_stages
comptime num_pipeline_stages = config.num_pipeline_stages
NUM_THREADS
comptime NUM_THREADS = Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].WarpRole.TOTAL_THREADS
NUM_TMEM_COLS
comptime NUM_TMEM_COLS = 512
OffsetsTile
comptime OffsetsTile = TileTensor[DType.uint32, Layout[*?, *?], MutAnyOrigin]
opc
comptime opc = OutputPipelineConfig(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].num_accum_pipeline_stages, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].stage_stride_cols, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].cta_group)
OutputM
comptime OutputM = config.output_tile_shape[0]
OutputN
comptime OutputN = config.output_tile_shape[1]
OutputPipeline
comptime OutputPipeline = OutputTilePipeline[Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].opc]
sf_atom_u16
comptime sf_atom_u16 = (((SF_ATOM_M[0] * SF_ATOM_M[1]) * 4) // 2)
sf_tma_dtype
comptime sf_tma_dtype = DType.uint16
sfa_expected_bytes
comptime sfa_expected_bytes = (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.sfa_smem_layout.size() * size_of[sfa_dtype]())
SFA_NUM_COLS
comptime SFA_NUM_COLS = (config * (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].BM // 32))
SFADescLayout
comptime SFADescLayout = Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SFATileLayout
SFATileLayout
comptime SFATileLayout = Layout[*?, *?]
SFATmaOp
comptime SFATmaOp = TMATensorTile[DType.uint16, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]
sfb_atom_u16
comptime sfb_atom_u16 = (((Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SFB_TMA_ROWS * SF_ATOM_M[1]) * 4) // 2)
sfb_expected_bytes
comptime sfb_expected_bytes = (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.sfb_smem_layout.size() * size_of[sfb_dtype]())
SFB_N_ALIGNED
comptime SFB_N_ALIGNED = align_up(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].MMA_N, SF_MN_GROUP_SIZE)
SFB_NUM_COLS
comptime SFB_NUM_COLS = (config * (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SFB_N_ALIGNED // 32))
SFB_TMA_K_ATOMS
comptime SFB_TMA_K_ATOMS = 1 if (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].MMA_N < 64) else config.num_sf_k_tiles
SFB_TMA_ROWS
comptime SFB_TMA_ROWS = Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].MMA_N if (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].MMA_N < SF_ATOM_M[0]) else SF_ATOM_M[0]
SFBDescLayout
comptime SFBDescLayout = Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SFBTileLayout
SFBTileLayout
comptime SFBTileLayout = Layout[*?, *?]
SFBTmaOp
comptime SFBTmaOp = TMATensorTile[DType.uint16, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]
SmemType
comptime SmemType = Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config]
stage_stride_cols
comptime stage_stride_cols = Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].MMA_N
TilePayload
comptime TilePayload = BlockScaledTilePayload[a_type, b_type, sfa_dtype, sfb_dtype, IndexList(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BK, __list_literal__=Tuple()), IndexList(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BN, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BK, __list_literal__=Tuple()), IndexList(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFA_DIM0, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFA_DIM1, __list_literal__=Tuple()), IndexList(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFB_DIM0, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFB_DIM1, __list_literal__=Tuple()), Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.num_pipeline_stages]
TileWriterType
comptime TileWriterType = TileWriter[a_type, DType.float32, config.block_tile_shape, config.mma_shape, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].opc, config.c_swizzle, config.AB_swapped, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.OutputM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.OutputN, config.num_output_stages, 4, problem_n=static_N]
Tmem
comptime Tmem = TmemAllocation[Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].opc.cta_group]
TmemDealloc
comptime TmemDealloc = TmemDeallocBarrier[Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].opc.cta_group]
TmemRegion
comptime TmemRegion = BlockScaledTmem[DType.float32, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].MMA_M, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].MMA_N, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].num_accum_pipeline_stages, sfa_dtype, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].BM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].num_pipeline_stages, cta_group=Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].cta_group, num_sf_k_tiles=config.num_sf_k_tiles, SFB_N=Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SFB_N_ALIGNED]
WarpRole
comptime WarpRole = WarpRole1D1D[(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].MMA_N < 64)]
WorkIterator
comptime WorkIterator = GroupedWorkIterator1D1D[static_N, config.block_tile_shape, config.cluster_shape, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].cta_group, AB_swapped=config.AB_swapped]
Methods
validate_config
static validate_config()
Compile-time validation of kernel configuration.
init_barriers
static init_barriers(elect_one_warp: Bool, elect_one_thread: Bool, a_tma_op: TMATensorTile[a_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], b_tma_op: TMATensorTile[b_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], c_tma_op: TMATensorTile[c_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], sfa_tma_op: TMATensorTile[DType.uint16, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], sfb_tma_op: TMATensorTile[DType.uint16, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], input_barriers: SMemArray[SharedMemBarrier, (Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Core.num_group_pipeline_stages * 2)], accum_barriers: SMemArray[SharedMemBarrier, (Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Core.num_accum_pipeline_stages * 2)], tmem_dealloc: SMemArray[SharedMemBarrier, 1])
Initialize barriers and prefetch TMA descriptors.
run
static run(a_tma_op: TMATensorTile[a_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], b_tma_op: TMATensorTile[b_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], c_tma_op: TMATensorTile[c_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], sfa_tma_op: TMATensorTile[DType.uint16, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], sfb_tma_op: TMATensorTile[DType.uint16, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], a_offsets: TileTensor[DType.uint32, Layout[*?, *?], MutAnyOrigin], a_scale_offsets: TileTensor[DType.uint32, Layout[*?, *?], MutAnyOrigin], expert_ids: TileTensor[DType.int32, Layout[*?, *?], MutAnyOrigin], expert_scales: TileTensor[DType.float32, Layout[*?, *?], MutAnyOrigin], c_device: TileTensor[c_type, c_device_layout, MutAnyOrigin], num_active_experts: Int, K: UInt32, sfb_global_ptr: UnsafePointer[Scalar[sfb_dtype], ImmutAnyOrigin], sfb_n_stride: Int, sfb_k_tiles: Int)
Grouped 1D-1D block-scaled GEMM kernel entry point.
Uses grid-constant TMAs with offset-based addressing for 1D-1D layout.
load_input_tiles
static load_input_tiles[tiles_origin: MutOrigin, //](a_tma_op: TMATensorTile[a_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], b_tma_op: TMATensorTile[b_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], sfa_tma_op: TMATensorTile[DType.uint16, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], sfb_tma_op: TMATensorTile[DType.uint16, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], tiles: ProducerTiles[tiles_origin, BlockScaledTilePayload[a_type, b_type, sfa_dtype, sfb_dtype, IndexList(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BK, __list_literal__=Tuple()), IndexList(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BN, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BK, __list_literal__=Tuple()), IndexList(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFA_DIM0, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFA_DIM1, __list_literal__=Tuple()), IndexList(Grouped1D1DMatmulKernel[a_type, 
b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFB_DIM0, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFB_DIM1, __list_literal__=Tuple()), Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.num_pipeline_stages], Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.num_group_pipeline_stages, config.k_group_size], peer_cta_coord: Tuple[Int, Int, Int], work_ctx: GroupedWorkContext1D1D, a_scale_offsets: TileTensor[DType.uint32, Layout[*?, *?], MutAnyOrigin], iter_idx: UInt32, elect_one_cta: Bool, a_multicast_mask: UInt16, b_multicast_mask: UInt16)
Load A, B, SFA, SFB tiles using TMA.
mma
static mma[tiles_origin: MutOrigin, //](tiles: ConsumerTiles[tiles_origin, BlockScaledTilePayload[a_type, b_type, sfa_dtype, sfb_dtype, IndexList(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BK, __list_literal__=Tuple()), IndexList(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BN, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BK, __list_literal__=Tuple()), IndexList(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFA_DIM0, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFA_DIM1, __list_literal__=Tuple()), IndexList(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFB_DIM0, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFB_DIM1, __list_literal__=Tuple()), Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.num_pipeline_stages], Grouped1D1DMatmulKernel[a_type, 
b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.num_group_pipeline_stages, config.k_group_size], mma_op: MmaOpSM100_BlockScaled_SS[c_type, a_type, b_type, sfa_dtype, sfb_dtype, config.scaling_kind, config.block_tile_shape, config.mma_shape, cta_group=Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b], tmem_addr: UInt32, tmem_region: BlockScaledTmem[DType.float32, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].MMA_M, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].MMA_N, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].num_accum_pipeline_stages, sfa_dtype, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].BM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].num_pipeline_stages, cta_group=Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].cta_group, num_sf_k_tiles=config.num_sf_k_tiles, SFB_N=Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, 
elementwise_compute_lambda_fn].SFB_N_ALIGNED], iter_idx: UInt32, k_start: UInt32, sfb_tmem_adj: UInt32)
Execute MMA operations.
For MMA_N >= 64, SFB is loaded into TMEM via tcgen05_cp inside mma_op.mma(). For MMA_N < 64, SFB is pre-loaded by dedicated SFB load warps via tcgen05_st; in that case the MMA warp waits on sfb_load_mbars before entering this function.
epilogue
static epilogue(c_tiles: SMemTileArray2DRowMajor[c_type, BlockScaledTileCore[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputM, BlockScaledTileCore[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputN, BlockScaledTileCore[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_output_stages], c_tma_op: TMATensorTile[c_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], c_device: TileTensor[c_type, c_device_layout, MutAnyOrigin], stage: OutputStage[Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].opc], work_ctx: GroupedWorkContext1D1D)
Execute the epilogue: store the accumulated results to the output, applying the expert_scale factor.
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!