Skip to main content

Mojo struct

Grouped1D1DMatmulKernel

struct Grouped1D1DMatmulKernel[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, c_device_layout: TensorLayout, transpose_b: Bool, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b], static_N: Int, cluster_shape: StaticTuple[Int32, 3] = StaticTuple(Int32(1)), elementwise_compute_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None]

Grouped 1D-1D block-scaled matmul kernel.

Uses 3-warp specialization (Load, MMA, Epilogue) with grid-constant TMAs. Work distribution via GroupedWorkIterator1D1D using offset-based addressing.

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

a_expected_bytes

comptime a_expected_bytes = ((Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].BM * Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].BK) * size_of[a_type]())

a_swizzle_elems

comptime a_swizzle_elems = (config.a_swizzle.bytes() // size_of[a_type]())

a_tile_dim0

comptime a_tile_dim0 = compute_tma_tile_dims[Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].BM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].BN, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].MMA_M, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].OutputM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_M, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_N, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].cta_group, config.AB_swapped]()[0]

a_tma_load_size

comptime a_tma_load_size = (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0 * Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].a_swizzle_elems)

a_tma_rows

comptime a_tma_rows = Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0

accum_pipeline_consumer_arv_count

comptime accum_pipeline_consumer_arv_count = compute_accum_barrier_counts[128, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].cta_group]()[1]

accum_pipeline_producer_arv_count

comptime accum_pipeline_producer_arv_count = compute_accum_barrier_counts[128, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].cta_group]()[0]

accum_type

comptime accum_type = DType.float32

ADescLayout

comptime ADescLayout = Layout[*?, *?]

AScaleOffsetsTile

comptime AScaleOffsetsTile = TileTensor[DType.uint32, Layout[*?, *?], MutAnyOrigin]

ATileLayout

comptime ATileLayout = Layout[*?, *?]

ATmaOp

comptime ATmaOp = TMATensorTile[a_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]

b_expected_bytes

comptime b_expected_bytes = ((Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].BN * Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].BK) * size_of[b_type]())

b_swizzle_elems

comptime b_swizzle_elems = (config.b_swizzle.bytes() // size_of[b_type]())

b_tile_dim0

comptime b_tile_dim0 = compute_tma_tile_dims[Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].BM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].BN, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].MMA_M, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].OutputM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_M, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_N, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].cta_group, config.AB_swapped]()[1]

b_tma_load_size

comptime b_tma_load_size = (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0 * Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].b_swizzle_elems)

b_tma_rows

comptime b_tma_rows = Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0

BDescLayout

comptime BDescLayout = Layout[*?, *?]

BK

comptime BK = config.block_tile_shape[2]

BM

comptime BM = config.block_tile_shape[0]

BN

comptime BN = config.block_tile_shape[1]

BTileLayout

comptime BTileLayout = Layout[*?, *?]

BTmaOp

comptime BTmaOp = TMATensorTile[b_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]

c_desc_dim1

comptime c_desc_dim1 = (config.c_swizzle.bytes() // size_of[c_type]()) if config.AB_swapped else config.output_tile_shape[1] if (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].c_swizzle_elems == 0) else Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].c_swizzle_elems

c_swizzle_elems

comptime c_swizzle_elems = (config.c_swizzle.bytes() // size_of[c_type]())

c_tile_dim0

comptime c_tile_dim0 = compute_tma_tile_dims[Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].BM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].BN, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].MMA_M, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].OutputM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_M, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_N, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].cta_group, config.AB_swapped]()[2]

c_tile_dim1

comptime c_tile_dim1 = Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].c_swizzle_elems if config.AB_swapped else Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].OutputN

CDescLayout

comptime CDescLayout = Layout[*?, *?]

CDeviceTile

comptime CDeviceTile = TileTensor[c_type, c_device_layout, MutAnyOrigin]

CLUSTER_M

comptime CLUSTER_M = config.cluster_shape[0]

CLUSTER_N

comptime CLUSTER_N = config.cluster_shape[1]

CLUSTER_SIZE

comptime CLUSTER_SIZE = (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_M * Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_N)

cta_group

comptime cta_group = config.cta_group

CTileLayout

comptime CTileLayout = Layout[*?, *?]

CTmaOp

comptime CTmaOp = TMATensorTile[c_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]

EpilogueCtx

comptime EpilogueCtx = EpilogueWarpContext[Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].opc, 32, 128]

ExpertIdsTile

comptime ExpertIdsTile = TileTensor[DType.int32, Layout[*?, *?], MutAnyOrigin]

ExpertScalesTile

comptime ExpertScalesTile = TileTensor[DType.float32, Layout[*?, *?], MutAnyOrigin]

input_expected_bytes

comptime input_expected_bytes = ((Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].cta_group * ((Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].a_expected_bytes + Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].b_expected_bytes) + Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].sfa_expected_bytes + (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].sfb_expected_bytes if (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].MMA_N >= 64) else 0))) * config.k_group_size)

InputTilePipelineType

comptime InputTilePipelineType = InputTilePipeline[BlockScaledTilePayload[a_type, b_type, sfa_dtype, sfb_dtype, IndexList(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BK, __list_literal__=Tuple()), IndexList(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BN, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BK, __list_literal__=Tuple()), IndexList(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFA_DIM0, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFA_DIM1, __list_literal__=Tuple()), IndexList(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFB_DIM0, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFB_DIM1, __list_literal__=Tuple()), Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.num_pipeline_stages], Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, 
sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.num_group_pipeline_stages, config.k_group_size]

MMA_K

comptime MMA_K = config.mma_shape[2]

MMA_M

comptime MMA_M = config.mma_shape[0]

MMA_N

comptime MMA_N = config.mma_shape[1]

MmaCtx

comptime MmaCtx = MmaWarpContext[Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].opc, 32, 128]

MmaEpilogueSync

comptime MmaEpilogueSync = WarpGroupBarrier[160, 1]

MmaOp

comptime MmaOp = MmaOpSM100_BlockScaled_SS[c_type, a_type, b_type, sfa_dtype, sfb_dtype, config.scaling_kind, config.block_tile_shape, config.mma_shape, cta_group=Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b]

MmaSfbSync

comptime MmaSfbSync = WarpGroupBarrier[160, 2]

num_accum_pipeline_stages

comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages

num_group_pipeline_stages

comptime num_group_pipeline_stages = (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].num_pipeline_stages // config.k_group_size)

num_output_stages

comptime num_output_stages = config.num_output_stages

num_output_warps

comptime num_output_warps = 4

num_pipeline_stages

comptime num_pipeline_stages = config.num_pipeline_stages

NUM_THREADS

comptime NUM_THREADS = Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].WarpRole.TOTAL_THREADS

NUM_TMEM_COLS

comptime NUM_TMEM_COLS = 512

OffsetsTile

comptime OffsetsTile = TileTensor[DType.uint32, Layout[*?, *?], MutAnyOrigin]

opc

comptime opc = OutputPipelineConfig(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].num_accum_pipeline_stages, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].stage_stride_cols, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].cta_group)

OutputM

comptime OutputM = config.output_tile_shape[0]

OutputN

comptime OutputN = config.output_tile_shape[1]

OutputPipeline

comptime OutputPipeline = OutputTilePipeline[Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].opc]

sf_atom_u16

comptime sf_atom_u16 = (((SF_ATOM_M[0] * SF_ATOM_M[1]) * 4) // 2)

sf_tma_dtype

comptime sf_tma_dtype = DType.uint16

sfa_expected_bytes

comptime sfa_expected_bytes = (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.sfa_smem_layout.size() * size_of[sfa_dtype]())

SFA_NUM_COLS

comptime SFA_NUM_COLS = (config.num_sf_k_tiles * (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].BM // 32))

SFADescLayout

comptime SFADescLayout = Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SFATileLayout

SFATileLayout

comptime SFATileLayout = Layout[*?, *?]

SFATmaOp

comptime SFATmaOp = TMATensorTile[DType.uint16, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]

sfb_atom_u16

comptime sfb_atom_u16 = (((Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SFB_TMA_ROWS * SF_ATOM_M[1]) * 4) // 2)

sfb_expected_bytes

comptime sfb_expected_bytes = (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.sfb_smem_layout.size() * size_of[sfb_dtype]())

SFB_N_ALIGNED

comptime SFB_N_ALIGNED = align_up(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].MMA_N, SF_MN_GROUP_SIZE)

SFB_NUM_COLS

comptime SFB_NUM_COLS = (config.num_sf_k_tiles * (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SFB_N_ALIGNED // 32))

SFB_TMA_K_ATOMS

comptime SFB_TMA_K_ATOMS = 1 if (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].MMA_N < 64) else config.num_sf_k_tiles

SFB_TMA_ROWS

comptime SFB_TMA_ROWS = Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].MMA_N if (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].MMA_N < SF_ATOM_M[0]) else SF_ATOM_M[0]

SFBDescLayout

comptime SFBDescLayout = Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SFBTileLayout

SFBTileLayout

comptime SFBTileLayout = Layout[*?, *?]

SFBTmaOp

comptime SFBTmaOp = TMATensorTile[DType.uint16, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]

SmemType

comptime SmemType = Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config]

stage_stride_cols

comptime stage_stride_cols = Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].MMA_N

TilePayload

comptime TilePayload = BlockScaledTilePayload[a_type, b_type, sfa_dtype, sfb_dtype, IndexList(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BK, __list_literal__=Tuple()), IndexList(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BN, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BK, __list_literal__=Tuple()), IndexList(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFA_DIM0, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFA_DIM1, __list_literal__=Tuple()), IndexList(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFB_DIM0, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFB_DIM1, __list_literal__=Tuple()), Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.num_pipeline_stages]

TileWriterType

comptime TileWriterType = TileWriter[a_type, DType.float32, config.block_tile_shape, config.mma_shape, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].opc, config.c_swizzle, config.AB_swapped, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.OutputM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.OutputN, config.num_output_stages, 4, problem_n=static_N]

Tmem

comptime Tmem = TmemAllocation[Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].opc.cta_group]

TmemDealloc

comptime TmemDealloc = TmemDeallocBarrier[Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].opc.cta_group]

TmemRegion

comptime TmemRegion = BlockScaledTmem[DType.float32, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].MMA_M, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].MMA_N, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].num_accum_pipeline_stages, sfa_dtype, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].BM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].num_pipeline_stages, cta_group=Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].cta_group, num_sf_k_tiles=config.num_sf_k_tiles, SFB_N=Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SFB_N_ALIGNED]

WarpRole

comptime WarpRole = WarpRole1D1D[(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].MMA_N < 64)]

WorkIterator

comptime WorkIterator = GroupedWorkIterator1D1D[static_N, config.block_tile_shape, config.cluster_shape, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].cta_group, AB_swapped=config.AB_swapped]

Methods

validate_config

static validate_config()

Compile-time validation of kernel configuration.

init_barriers

static init_barriers(elect_one_warp: Bool, elect_one_thread: Bool, a_tma_op: TMATensorTile[a_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], b_tma_op: TMATensorTile[b_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], c_tma_op: TMATensorTile[c_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], sfa_tma_op: TMATensorTile[DType.uint16, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], sfb_tma_op: TMATensorTile[DType.uint16, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], input_barriers: SMemArray[SharedMemBarrier, (Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Core.num_group_pipeline_stages * 2)], accum_barriers: SMemArray[SharedMemBarrier, (Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Core.num_accum_pipeline_stages * 2)], tmem_dealloc: SMemArray[SharedMemBarrier, 1])

Initialize barriers and prefetch TMA descriptors.

run

static run(a_tma_op: TMATensorTile[a_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], b_tma_op: TMATensorTile[b_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], c_tma_op: TMATensorTile[c_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], sfa_tma_op: TMATensorTile[DType.uint16, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], sfb_tma_op: TMATensorTile[DType.uint16, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], a_offsets: TileTensor[DType.uint32, Layout[*?, *?], MutAnyOrigin], a_scale_offsets: TileTensor[DType.uint32, Layout[*?, *?], MutAnyOrigin], expert_ids: TileTensor[DType.int32, Layout[*?, *?], MutAnyOrigin], expert_scales: TileTensor[DType.float32, Layout[*?, *?], MutAnyOrigin], c_device: TileTensor[c_type, c_device_layout, MutAnyOrigin], num_active_experts: Int, K: UInt32, sfb_global_ptr: UnsafePointer[Scalar[sfb_dtype], ImmutAnyOrigin], sfb_n_stride: Int, sfb_k_tiles: Int)

Grouped 1D-1D block-scaled GEMM kernel entry point.

Uses grid-constant TMAs with offset-based addressing for 1D-1D layout.

load_input_tiles

static load_input_tiles[tiles_origin: MutOrigin, //](a_tma_op: TMATensorTile[a_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], b_tma_op: TMATensorTile[b_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], sfa_tma_op: TMATensorTile[DType.uint16, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], sfb_tma_op: TMATensorTile[DType.uint16, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], tiles: ProducerTiles[tiles_origin, BlockScaledTilePayload[a_type, b_type, sfa_dtype, sfb_dtype, IndexList(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BK, __list_literal__=Tuple()), IndexList(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BN, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BK, __list_literal__=Tuple()), IndexList(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFA_DIM0, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFA_DIM1, __list_literal__=Tuple()), IndexList(Grouped1D1DMatmulKernel[a_type, 
b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFB_DIM0, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFB_DIM1, __list_literal__=Tuple()), Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.num_pipeline_stages], Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.num_group_pipeline_stages, config.k_group_size], peer_cta_coord: Tuple[Int, Int, Int], work_ctx: GroupedWorkContext1D1D, a_scale_offsets: TileTensor[DType.uint32, Layout[*?, *?], MutAnyOrigin], iter_idx: UInt32, elect_one_cta: Bool, a_multicast_mask: UInt16, b_multicast_mask: UInt16)

Load A, B, SFA, SFB tiles using TMA.

mma

static mma[tiles_origin: MutOrigin, //](tiles: ConsumerTiles[tiles_origin, BlockScaledTilePayload[a_type, b_type, sfa_dtype, sfb_dtype, IndexList(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BK, __list_literal__=Tuple()), IndexList(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BN, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BK, __list_literal__=Tuple()), IndexList(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFA_DIM0, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFA_DIM1, __list_literal__=Tuple()), IndexList(Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFB_DIM0, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFB_DIM1, __list_literal__=Tuple()), Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.num_pipeline_stages], Grouped1D1DMatmulKernel[a_type, 
b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.num_group_pipeline_stages, config.k_group_size], mma_op: MmaOpSM100_BlockScaled_SS[c_type, a_type, b_type, sfa_dtype, sfb_dtype, config.scaling_kind, config.block_tile_shape, config.mma_shape, cta_group=Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b], tmem_addr: UInt32, tmem_region: BlockScaledTmem[DType.float32, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].MMA_M, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].MMA_N, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].num_accum_pipeline_stages, sfa_dtype, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].BM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].num_pipeline_stages, cta_group=Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].cta_group, num_sf_k_tiles=config.num_sf_k_tiles, SFB_N=Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, 
elementwise_compute_lambda_fn].SFB_N_ALIGNED], iter_idx: UInt32, k_start: UInt32, sfb_tmem_adj: UInt32)

Execute MMA operations.

For MMA_N >= 64: SFB is loaded to TMEM via tcgen05_cp inside mma_op.mma(). For MMA_N < 64: SFB is pre-loaded by dedicated SFB load warps via tcgen05_st. The MMA warp waits on sfb_load_mbars before entering this function.

epilogue

static epilogue(c_tiles: SMemTileArray2DRowMajor[c_type, BlockScaledTileCore[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputM, BlockScaledTileCore[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputN, BlockScaledTileCore[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_output_stages], c_tma_op: TMATensorTile[c_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], c_device: TileTensor[c_type, c_device_layout, MutAnyOrigin], stage: OutputStage[Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn].opc], work_ctx: GroupedWorkContext1D1D)

Executes the epilogue, storing the accumulated results with the expert_scale factor applied.

Was this page helpful?