Mojo struct
GroupedBlockScaledSmem
struct GroupedBlockScaledSmem[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, transpose_b: Bool, *, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b]]
Shared memory (SMEM) struct for grouped block-scaled GEMM.
Extends the standard BlockScaledSmem with:
- 5 TMA descriptor slots for dynamic tensormap updates (A, B, SFA, SFB, C)
- Each descriptor is 128 bytes with 128-byte alignment
Layout in SMEM:
- Tensormap descriptors (5 x 128 bytes = 640 bytes)
- A tiles
- B tiles
- C tiles
- SFA tiles
- SFB tiles
- Pipeline barriers
- CLC barriers
- TMEM state
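The descriptor arithmetic above can be sketched in a few lines. This is a Python model of the byte layout only, not the Mojo API. It assumes the five slots are packed back to back in the listed order (A, B, SFA, SFB, C) starting at a 128-byte-aligned base; the helper names are hypothetical.

```python
TENSORMAP_BYTES = 128  # size and alignment of one TMA descriptor, per the text above
SLOT_NAMES = ("A", "B", "SFA", "SFB", "C")  # order as listed above


def tensormap_offsets(base=0):
    """Byte offset of each descriptor slot.

    Assumption: slots are contiguous, so a 128-byte-aligned base keeps
    every slot 128-byte aligned.
    """
    if base % TENSORMAP_BYTES != 0:
        raise ValueError("base must be 128-byte aligned")
    return {name: base + i * TENSORMAP_BYTES for i, name in enumerate(SLOT_NAMES)}


def tensormap_storage_bytes():
    """Total descriptor storage: 5 slots x 128 bytes."""
    return len(SLOT_NAMES) * TENSORMAP_BYTES


offsets = tensormap_offsets()
print(offsets)
print(tensormap_storage_bytes())
```

Under these assumptions the first region after the descriptors (the A tiles) would begin at byte 640, which is itself a multiple of 128.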
Fields
- tiles (GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles)
- pipelines (GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Pipelines)
- tensormap_a (TMADescriptor)
- tensormap_b (TMADescriptor)
- tensormap_sfa (TMADescriptor)
- tensormap_sfb (TMADescriptor)
- tensormap_c (TMADescriptor)
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
__del__is_trivial
comptime __del__is_trivial = True
a_smem_layout
comptime a_smem_layout = GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Layouts.a_smem_layout
ATileArray
comptime ATileArray = GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles.ATileArray
b_smem_layout
comptime b_smem_layout = GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Layouts.b_smem_layout
BK
comptime BK = config.block_tile_shape.__getitem__[3, DType.int64, Int](2)
BM
comptime BM = config.block_tile_shape.__getitem__[3, DType.int64, Int](0)
BN
comptime BN = config.block_tile_shape.__getitem__[3, DType.int64, Int](1)
BTileArray
comptime BTileArray = GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles.BTileArray
c_smem_layout
comptime c_smem_layout = GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Layouts.c_smem_layout
CTileArray
comptime CTileArray = GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles.CTileArray
Layouts
comptime Layouts = SmemLayouts[a_type, b_type, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BM, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BN, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputM, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputN, config.a_swizzle, config.b_swizzle, transpose_b]
MMA_M
comptime MMA_M = config.mma_shape.__getitem__[3, DType.int64, Int](0)
MMA_N
comptime MMA_N = config.mma_shape.__getitem__[3, DType.int64, Int](1)
num_accum_pipeline_stages
comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages
num_clc_pipeline_stages
comptime num_clc_pipeline_stages = config.num_clc_pipeline_stages
num_group_pipeline_stages
comptime num_group_pipeline_stages = (GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_pipeline_stages // config)
num_output_stages
comptime num_output_stages = config.num_output_stages
num_pipeline_stages
comptime num_pipeline_stages = config.num_pipeline_stages
OutputM
comptime OutputM = config.output_tile_shape.__getitem__[2, DType.int64, Int](0)
OutputN
comptime OutputN = config.output_tile_shape.__getitem__[2, DType.int64, Int](1)
Pipelines
comptime Pipelines = SmemPipelineBundle[GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_group_pipeline_stages, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_accum_pipeline_stages, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_clc_pipeline_stages, BlockScaledTilePayload[a_type, b_type, sfa_dtype, sfb_dtype, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BM, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BN, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFA_DIM0, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFA_DIM1, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFB_DIM0, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFB_DIM1, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_pipeline_stages]]
SF_BK
comptime SF_BK = sf_bk[config]()
SF_K_GROUP_SIZE
comptime SF_K_GROUP_SIZE = sf_k_group_size[config]()
SFA_DIM0
comptime SFA_DIM0 = sfa_dim0[config]()
SFA_DIM1
comptime SFA_DIM1 = sfa_dim1[config]()
sfa_smem_layout
comptime sfa_smem_layout = tile_sf_layout_k_major[GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BM, (GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SF_K_GROUP_SIZE * config), config.vec_sf_size]()
SFATileArray
comptime SFATileArray = GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles.SFATileArray
SFB_DIM0
comptime SFB_DIM0 = sfb_dim0[config]()
SFB_DIM1
comptime SFB_DIM1 = sfb_dim1[config]()
sfb_smem_layout
comptime sfb_smem_layout = tile_sf_layout_k_major[GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].MMA_N, (GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SF_K_GROUP_SIZE * config), config.vec_sf_size]()
SFBTileArray
comptime SFBTileArray = GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles.SFBTileArray
Tiles
comptime Tiles = BlockScaledTileStorage[a_type, b_type, c_type, sfa_dtype, sfb_dtype, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BM, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BN, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputM, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputN, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFA_DIM0, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFA_DIM1, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFB_DIM0, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFB_DIM1, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_pipeline_stages, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_output_stages]
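As a rough guide to how the Tiles parameters relate to shared-memory usage, the following Python sketch estimates the bytes taken by the A, B, and C tile arrays. It is not the library's computation. It assumes A tiles are BM x BK, B tiles are BN x BK, and C tiles are OutputM x OutputN, with A and B replicated num_pipeline_stages times and C replicated num_output_stages times. All numeric values and element sizes are hypothetical, and scale-factor tiles are left out.

```python
# Hypothetical example values; none of these come from the reference page.
BM, BN, BK = 128, 128, 64
OUTPUT_M, OUTPUT_N = 128, 32
NUM_PIPELINE_STAGES = 4
NUM_OUTPUT_STAGES = 2
A_BYTES, B_BYTES, C_BYTES = 1, 1, 2  # assumed element sizes in bytes


def staged_tile_bytes(rows, cols, elem_bytes, stages):
    """Bytes for `stages` copies of a rows x cols tile."""
    return rows * cols * elem_bytes * stages


sizes = {
    "A": staged_tile_bytes(BM, BK, A_BYTES, NUM_PIPELINE_STAGES),
    "B": staged_tile_bytes(BN, BK, B_BYTES, NUM_PIPELINE_STAGES),
    "C": staged_tile_bytes(OUTPUT_M, OUTPUT_N, C_BYTES, NUM_OUTPUT_STAGES),
}
total = sum(sizes.values())
print(sizes, total)
```

An estimate like this is mainly useful for checking whether a candidate configuration is likely to fit in the shared memory available per block before compiling a kernel.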
Methods
a_tiles
a_tiles(ref[AddressSpace._value._mlir_value] self) -> GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ATileArray
Get A tile array accessor (TileTensor-based).
Returns:
GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ATileArray
b_tiles
b_tiles(ref[AddressSpace._value._mlir_value] self) -> GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BTileArray
Get B tile array accessor (TileTensor-based).
Returns:
GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BTileArray
c_tiles
c_tiles(ref[AddressSpace._value._mlir_value] self) -> GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].CTileArray
Get C tile array accessor (TileTensor-based).
Returns:
GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].CTileArray
sfa_tiles
sfa_tiles(ref[AddressSpace._value._mlir_value] self) -> GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFATileArray
Get SFA tile array accessor (TileTensor-based).
Returns:
GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFATileArray
sfb_tiles
sfb_tiles(ref[AddressSpace._value._mlir_value] self) -> GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFBTileArray
Get SFB tile array accessor (TileTensor-based).
Returns:
GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFBTileArray
tensormap_storage_size
static tensormap_storage_size() -> Int
Size of tensormap storage in bytes (5 x 128 = 640 bytes).
Returns:
Int
total_tile_size