Mojo struct
Grouped1D1DMatmulKernel
struct Grouped1D1DMatmulKernel[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, sfa_layout: Layout, sfb_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, c_desc_layout: Layout, sfa_desc_layout: Layout, sfb_desc_layout: Layout, offsets_layout: Layout, expert_ids_layout: Layout, expert_scales_layout: Layout, a_scale_offsets_layout: Layout, c_device_layout: Layout, transpose_b: Bool, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b], static_N: Int, cluster_shape: StaticTuple[Int32, 3] = StaticTuple[Int32, 3](1), elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, register_based_epilogue: Bool = True]
Grouped 1D-1D block-scaled matmul kernel.
Uses warp specialization with three roles (Load, MMA, Epilogue) and grid-constant TMAs. Work is distributed by GroupedWorkIterator1D1D using offset-based addressing.
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
__del__is_trivial
comptime __del__is_trivial = True
a_expected_bytes
comptime a_expected_bytes = (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.a_smem_layout.size() * size_of[a_type]())
a_tma_load_size
comptime a_tma_load_size = a_desc_layout.size()
a_tma_rows
comptime a_tma_rows = a_desc_layout.shape[1].value()
accum_pipeline_consumer_arv_count
comptime accum_pipeline_consumer_arv_count = (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group * 128)
accum_pipeline_producer_arv_count
comptime accum_pipeline_producer_arv_count = 1
accum_type
comptime accum_type = DType.float32
b_expected_bytes
comptime b_expected_bytes = (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.b_smem_layout.size() * size_of[b_type]())
b_tma_load_size
comptime b_tma_load_size = b_desc_layout.size()
b_tma_rows
comptime b_tma_rows = b_desc_layout.shape[1].value()
BK
comptime BK = config.block_tile_shape.__getitem__[3, DType.int64, Int](2)
BM
comptime BM = config.block_tile_shape.__getitem__[3, DType.int64, Int](0)
BN
comptime BN = config.block_tile_shape.__getitem__[3, DType.int64, Int](1)
CLUSTER_M
comptime CLUSTER_M = config.cluster_shape.__getitem__[3, DType.int64, Int](0)
CLUSTER_N
comptime CLUSTER_N = config.cluster_shape.__getitem__[3, DType.int64, Int](1)
CLUSTER_SIZE
comptime CLUSTER_SIZE = (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_M * Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_N)
cta_group
comptime cta_group = config.cta_group
EpilogueCtx
comptime EpilogueCtx = EpilogueWarpContext[config.num_accum_pipeline_stages, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].stage_stride_cols, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, 32, 128]
input_expected_bytes
comptime input_expected_bytes = ((Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group * (((Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].a_expected_bytes + Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].b_expected_bytes) + Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].sfa_expected_bytes) + Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].sfb_expected_bytes)) * config.k_group_size)
InputTilePipelineType
comptime InputTilePipelineType = InputTilePipeline[Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].TilePayload, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.num_group_pipeline_stages, config.k_group_size]
MMA_K
comptime MMA_K = config.mma_shape.__getitem__[3, DType.int64, Int](2)
MMA_M
comptime MMA_M = config.mma_shape.__getitem__[3, DType.int64, Int](0)
MMA_N
comptime MMA_N = config.mma_shape.__getitem__[3, DType.int64, Int](1)
MmaCtx
comptime MmaCtx = MmaWarpContext[config.num_accum_pipeline_stages, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].stage_stride_cols, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, 32, 128]
MmaEpilogueSync
comptime MmaEpilogueSync = WarpGroupBarrier[160, 1]
MmaOp
comptime MmaOp = MmaOpSM100_BlockScaled_SS[c_type, a_type, b_type, sfa_dtype, sfb_dtype, config.scaling_kind, config.block_tile_shape, config.mma_shape, cta_group=Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b]
num_accum_pipeline_stages
comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages
num_group_pipeline_stages
comptime num_group_pipeline_stages = (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].num_pipeline_stages // config.k_group_size)
num_output_stages
comptime num_output_stages = config.num_output_stages
num_output_warps
comptime num_output_warps = 4
num_pipeline_stages
comptime num_pipeline_stages = config.num_pipeline_stages
NUM_THREADS
comptime NUM_THREADS = WarpRole1D1D.TOTAL_THREADS
NUM_TMEM_COLS
comptime NUM_TMEM_COLS = 512
OutputM
comptime OutputM = config.output_tile_shape.__getitem__[2, DType.int64, Int](0)
OutputN
comptime OutputN = config.output_tile_shape.__getitem__[2, DType.int64, Int](1)
OutputPipeline
comptime OutputPipeline = OutputTilePipeline[config.num_accum_pipeline_stages, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].stage_stride_cols, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group]
sfa_expected_bytes
comptime sfa_expected_bytes = (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.sfa_smem_layout.size() * size_of[sfa_dtype]())
SFA_NUM_COLS
comptime SFA_NUM_COLS = (config.num_sf_k_tiles * (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BM // 32))
sfb_expected_bytes
comptime sfb_expected_bytes = (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.sfb_smem_layout.size() * size_of[sfb_dtype]())
SFB_NUM_COLS
comptime SFB_NUM_COLS = (config.num_sf_k_tiles * (Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N // 32))
SmemType
comptime SmemType = Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config]
stage_stride_cols
comptime stage_stride_cols = Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N
TilePayload
comptime TilePayload = BlockScaledTilePayload[a_type, b_type, sfa_dtype, sfb_dtype, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.BM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.BK, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.BN, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.BK, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, 
expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.SFA_DIM0, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.SFA_DIM1, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.SFB_DIM0, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.SFB_DIM1, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.num_pipeline_stages]
TileWriterType
comptime TileWriterType = TileWriter[a_type, DType.float32, config.block_tile_shape, config.mma_shape, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, config.num_accum_pipeline_stages, config.c_swizzle, config.AB_swapped, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.OutputM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.OutputN, config.num_output_stages, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].stage_stride_cols, 4]
Tmem
comptime Tmem = TmemAllocation[Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group]
TmemDealloc
comptime TmemDealloc = TmemDeallocBarrier[Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group]
TmemRegion
comptime TmemRegion = BlockScaledTmem[DType.float32, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_M, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].num_accum_pipeline_stages, sfa_dtype, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, 
a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].num_pipeline_stages, cta_group=Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, num_sf_k_tiles=config.num_sf_k_tiles]
WorkIterator
comptime WorkIterator = GroupedWorkIterator1D1D[offsets_layout, expert_ids_layout, expert_scales_layout, static_N, config.block_tile_shape, config.cluster_shape, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group]
Methods
validate_config
static validate_config()
Compile-time validation of kernel configuration.
run
static run(a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], sfa_tma_op: TMATensorTile[sfa_dtype, sfa_layout, sfa_desc_layout], sfb_tma_op: TMATensorTile[sfb_dtype, sfb_layout, sfb_desc_layout], a_offsets: LayoutTensor[DType.uint32, offsets_layout, MutAnyOrigin], a_scale_offsets: LayoutTensor[DType.uint32, a_scale_offsets_layout, MutAnyOrigin], expert_ids: LayoutTensor[DType.int32, expert_ids_layout, MutAnyOrigin], expert_scales: LayoutTensor[DType.float32, expert_scales_layout, MutAnyOrigin], c_device: LayoutTensor[c_type, c_device_layout, MutAnyOrigin], num_active_experts: Int, K: UInt32)
Grouped 1D-1D block-scaled GEMM kernel entry point.
Uses grid-constant TMAs with offset-based addressing for 1D-1D layout.
load_input_tiles
static load_input_tiles[tiles_origin: MutOrigin, //](a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], sfa_tma_op: TMATensorTile[sfa_dtype, sfa_layout, sfa_desc_layout], sfb_tma_op: TMATensorTile[sfb_dtype, sfb_layout, sfb_desc_layout], tiles: InputProducerStage[tiles_origin, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].TilePayload, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.num_group_pipeline_stages, config.k_group_size], peer_cta_coord: Tuple[UInt, UInt, UInt], work_ctx: GroupedWorkContext1D1D, a_scale_offsets: LayoutTensor[DType.uint32, a_scale_offsets_layout, MutAnyOrigin], iter_idx: UInt32, elect_one_cta: Bool)
Load A, B, SFA, SFB tiles using TMA.
mma
static mma[tiles_origin: MutOrigin, //](tiles: InputConsumerStage[tiles_origin, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].TilePayload, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.num_group_pipeline_stages, config.k_group_size], mma_op: MmaOpSM100_BlockScaled_SS[c_type, a_type, b_type, sfa_dtype, sfb_dtype, config.scaling_kind, config.block_tile_shape, config.mma_shape, cta_group=Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b], tmem_addr: UInt32, tmem_region: BlockScaledTmem[DType.float32, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, 
c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_M, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].num_accum_pipeline_stages, sfa_dtype, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BM, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].num_pipeline_stages, cta_group=Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, 
sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, num_sf_k_tiles=config.num_sf_k_tiles], iter_idx: UInt32, k_start: UInt32)
Execute MMA operations.
epilogue
static epilogue(c_tiles: SMemTileArray2DRowMajor[c_type, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputM, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputN, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_output_stages], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], c_device: LayoutTensor[c_type, c_device_layout, MutAnyOrigin], stage: OutputStage[config.num_accum_pipeline_stages, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].stage_stride_cols, Grouped1D1DMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, offsets_layout, expert_ids_layout, expert_scales_layout, a_scale_offsets_layout, c_device_layout, transpose_b, config, static_N, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group], work_ctx: GroupedWorkContext1D1D)
Execute epilogue to store accumulated results with expert_scale.
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!