Skip to main content

Mojo struct

Grouped1D1DMatmulKernel

struct Grouped1D1DMatmulKernel[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, sfa_layout: Layout, sfb_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, c_desc_layout: Layout, sfa_desc_layout: Layout, sfb_desc_layout: Layout, offsets_layout: Layout, expert_ids_layout: Layout, expert_scales_layout: Layout, a_scale_offsets_layout: Layout, c_device_layout: Layout, transpose_b: Bool, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b], static_N: Int, cluster_shape: StaticTuple[Int32, 3] = StaticTuple[Int32, 3](1), elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, register_based_epilogue: Bool = True]

Grouped 1D-1D block-scaled matmul kernel.

Uses 3-warp specialization (Load, MMA, Epilogue) with grid-constant TMAs. Work is distributed via GroupedWorkIterator1D1D using offset-based addressing.

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

__del__is_trivial

comptime __del__is_trivial = True

a_expected_bytes

comptime a_expected_bytes = (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.a_smem_layout.size() * size_of[a_type]())

a_tma_load_size

comptime a_tma_load_size = a_desc_layout.size()

a_tma_rows

comptime a_tma_rows = a_desc_layout.shape[1].value()

accum_pipeline_consumer_arv_count

comptime accum_pipeline_consumer_arv_count = (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group * 128)

accum_pipeline_producer_arv_count

comptime accum_pipeline_producer_arv_count = 1

accum_type

comptime accum_type = DType.float32

b_expected_bytes

comptime b_expected_bytes = (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.b_smem_layout.size() * size_of[b_type]())

b_tma_load_size

comptime b_tma_load_size = b_desc_layout.size()

b_tma_rows

comptime b_tma_rows = b_desc_layout.shape[1].value()

BK

comptime BK = config.block_tile_shape.__getitem__[3, DType.int64, Int](2)

BM

comptime BM = config.block_tile_shape.__getitem__[3, DType.int64, Int](0)

BN

comptime BN = config.block_tile_shape.__getitem__[3, DType.int64, Int](1)

CLUSTER_M

comptime CLUSTER_M = config.cluster_shape.__getitem__[3, DType.int64, Int](0)

CLUSTER_N

comptime CLUSTER_N = config.cluster_shape.__getitem__[3, DType.int64, Int](1)

CLUSTER_SIZE

comptime CLUSTER_SIZE = (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_M * Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_N)

cta_group

comptime cta_group = config.cta_group

EpilogueCtx

comptime EpilogueCtx = EpilogueWarpContext[config.num_accum_pipeline_stages, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].stage_stride_cols, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, 32, 128]

input_expected_bytes

comptime input_expected_bytes = ((Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group * (((Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].a_expected_bytes + Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].b_expected_bytes) + Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].sfa_expected_bytes) + Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, 
a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].sfb_expected_bytes)) * config.k_group_size)

InputTilePipelineType

comptime InputTilePipelineType = InputTilePipeline[Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].TilePayload, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.num_group_pipeline_stages, config.k_group_size]

MMA_K

comptime MMA_K = config.mma_shape.__getitem__[3, DType.int64, Int](2)

MMA_M

comptime MMA_M = config.mma_shape.__getitem__[3, DType.int64, Int](0)

MMA_N

comptime MMA_N = config.mma_shape.__getitem__[3, DType.int64, Int](1)

MmaCtx

comptime MmaCtx = MmaWarpContext[config.num_accum_pipeline_stages, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].stage_stride_cols, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, 32, 128]

MmaEpilogueSync

comptime MmaEpilogueSync = WarpGroupBarrier[160, 1]

MmaOp

comptime MmaOp = MmaOpSM100_BlockScaled_SS[c_type, a_type, b_type, sfa_dtype, sfb_dtype, config.scaling_kind, config.block_tile_shape, config.mma_shape, cta_group=Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b]

num_accum_pipeline_stages

comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages

num_group_pipeline_stages

comptime num_group_pipeline_stages = (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].num_pipeline_stages // config.k_group_size)

num_output_stages

comptime num_output_stages = config.num_output_stages

num_output_warps

comptime num_output_warps = 4

num_pipeline_stages

comptime num_pipeline_stages = config.num_pipeline_stages

NUM_THREADS

comptime NUM_THREADS = WarpRole1D1D.TOTAL_THREADS

NUM_TMEM_COLS

comptime NUM_TMEM_COLS = 512

OutputM

comptime OutputM = config.output_tile_shape.__getitem__[2, DType.int64, Int](0)

OutputN

comptime OutputN = config.output_tile_shape.__getitem__[2, DType.int64, Int](1)

OutputPipeline

comptime OutputPipeline = OutputTilePipeline[config.num_accum_pipeline_stages, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].stage_stride_cols, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group]

sfa_expected_bytes

comptime sfa_expected_bytes = (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.sfa_smem_layout.size() * size_of[sfa_dtype]())

SFA_NUM_COLS

comptime SFA_NUM_COLS = (config.num_sf_k_tiles * (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BM // 32))

sfb_expected_bytes

comptime sfb_expected_bytes = (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.sfb_smem_layout.size() * size_of[sfb_dtype]())

SFB_NUM_COLS

comptime SFB_NUM_COLS = (config.num_sf_k_tiles * (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N // 32))

SmemType

comptime SmemType = Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config]

stage_stride_cols

comptime stage_stride_cols = Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N

TilePayload

comptime TilePayload = BlockScaledTilePayload[a_type, b_type, sfa_dtype, sfb_dtype, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.BM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.BK, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.BN, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.BK, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, 
expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.SFA_DIM0, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.SFA_DIM1, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.SFB_DIM0, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.SFB_DIM1, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.num_pipeline_stages]

TileWriterType

comptime TileWriterType = TileWriter[a_type, DType.float32, config.block_tile_shape, config.mma_shape, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, config.num_accum_pipeline_stages, config.c_swizzle, config.AB_swapped, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.OutputM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.OutputN, config.num_output_stages, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].stage_stride_cols, 4]

Tmem

comptime Tmem = TmemAllocation[Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group]

TmemDealloc

comptime TmemDealloc = TmemDeallocBarrier[Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group]

TmemRegion

comptime TmemRegion = BlockScaledTmem[DType.float32, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_M, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].num_accum_pipeline_stages, sfa_dtype, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, 
a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].num_pipeline_stages, cta_group=Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, num_sf_k_tiles=config.num_sf_k_tiles]

WorkIterator

comptime WorkIterator = GroupedWorkIterator1D1D[offsets_layout, expert_ids_layout, expert_scales_layout, static_N, config.block_tile_shape, config.cluster_shape, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group]

Methods

validate_config

static validate_config()

Compile-time validation of kernel configuration.

run

static run(a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], sfa_tma_op: TMATensorTile[sfa_dtype, sfa_layout, sfa_desc_layout], sfb_tma_op: TMATensorTile[sfb_dtype, sfb_layout, sfb_desc_layout], a_offsets: LayoutTensor[DType.uint32, offsets_layout, MutAnyOrigin], a_scale_offsets: LayoutTensor[DType.uint32, a_scale_offsets_layout, MutAnyOrigin], expert_ids: LayoutTensor[DType.int32, expert_ids_layout, MutAnyOrigin], expert_scales: LayoutTensor[DType.float32, expert_scales_layout, MutAnyOrigin], c_device: LayoutTensor[c_type, c_device_layout, MutAnyOrigin], num_active_experts: Int, K: UInt32)

Grouped 1D-1D block-scaled GEMM kernel entry point.

Uses grid-constant TMAs with offset-based addressing for 1D-1D layout.

load_input_tiles

static load_input_tiles[tiles_origin: MutOrigin, //](a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], sfa_tma_op: TMATensorTile[sfa_dtype, sfa_layout, sfa_desc_layout], sfb_tma_op: TMATensorTile[sfb_dtype, sfb_layout, sfb_desc_layout], tiles: InputProducerStage[tiles_origin, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].TilePayload, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.num_group_pipeline_stages, config.k_group_size], peer_cta_coord: Tuple[UInt, UInt, UInt], work_ctx: GroupedWorkContext1D1D, a_scale_offsets: LayoutTensor[DType.uint32, a_scale_offsets_layout, MutAnyOrigin], iter_idx: UInt32, elect_one_cta: Bool)

Load A, B, SFA, SFB tiles using TMA.

mma

static mma[tiles_origin: MutOrigin, //](tiles: InputConsumerStage[tiles_origin, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].TilePayload, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.num_group_pipeline_stages, config.k_group_size], mma_op: MmaOpSM100_BlockScaled_SS[c_type, a_type, b_type, sfa_dtype, sfb_dtype, config.scaling_kind, config.block_tile_shape, config.mma_shape, cta_group=Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b], tmem_addr: UInt32, tmem_region: BlockScaledTmem[DType.float32, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, 
c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_M, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].num_accum_pipeline_stages, sfa_dtype, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].num_pipeline_stages, cta_group=Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, 
sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, num_sf_k_tiles=config.num_sf_k_tiles], iter_idx: UInt32, k_start: UInt32)

Execute MMA operations.

epilogue

static epilogue(c_tiles: SMemTileArray2DRowMajor[c_type, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputM, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputN, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_output_stages], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], c_device: LayoutTensor[c_type, c_device_layout, MutAnyOrigin], stage: OutputStage[config.num_accum_pipeline_stages, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].stage_stride_cols, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group], work_ctx: GroupedWorkContext1D1D)

Execute epilogue to store accumulated results with expert_scale.

Was this page helpful?