
Mojo struct

GroupedBlockScaledSmem

struct GroupedBlockScaledSmem[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, transpose_b: Bool, *, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b]]

Shared memory (SMEM) layout struct for grouped block-scaled GEMM.

Extends the standard BlockScaledSmem with:

  • 5 TMA descriptor slots for dynamic tensormap updates (A, B, SFA, SFB, C)
  • Each descriptor is 128 bytes and 128-byte aligned
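As a rough illustration of the descriptor region, the Python sketch below computes the byte offset of each slot. It assumes the five slots are packed back to back in the order A, B, SFA, SFB, C (the order the fields are listed in); the API docs do not state the exact packing, and `descriptor_offsets` is a hypothetical helper, not part of the library.

```python
# Sketch: byte offsets of the five TMA descriptor slots.
# Assumption: slots are packed contiguously in the order A, B, SFA, SFB, C.
DESCRIPTOR_BYTES = 128  # each TMA descriptor is 128 bytes
DESCRIPTOR_ALIGN = 128  # and 128-byte aligned
SLOT_NAMES = ["A", "B", "SFA", "SFB", "C"]


def align_up(offset, alignment):
    """Round offset up to the next multiple of alignment."""
    return (offset + alignment - 1) // alignment * alignment


def descriptor_offsets(base=0):
    """Hypothetical helper: map each slot name to its byte offset."""
    offsets = {}
    cursor = align_up(base, DESCRIPTOR_ALIGN)
    for name in SLOT_NAMES:
        offsets[name] = cursor
        cursor += DESCRIPTOR_BYTES  # stays 128-aligned since size == alignment
    return offsets


print(descriptor_offsets())  # {'A': 0, 'B': 128, 'SFA': 256, 'SFB': 384, 'C': 512}
print(len(SLOT_NAMES) * DESCRIPTOR_BYTES)  # 640
```

Because each slot's size equals its alignment, only the region's starting address needs rounding. Every later slot stays aligned automatically.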

Layout in SMEM:

  1. Tensormap descriptors (5 x 128 bytes = 640 bytes)
  2. A tiles
  3. B tiles
  4. C tiles
  5. SFA tiles
  6. SFB tiles
  7. Pipeline barriers
  8. CLC barriers
  9. TMEM state
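The ordering above can be made concrete with a small offset calculator. This Python sketch walks the regions in the listed order and aligns each one before placing it. Only the 640-byte tensormap size comes from these docs. Every other size and alignment below is a hypothetical placeholder, because the real values depend on the `config` and on the tile and barrier types.

```python
# Sketch: assign SMEM offsets to the regions in the documented order.
# Only the tensormap size (5 x 128 = 640 bytes) is from the docs; all other
# sizes and alignments are hypothetical placeholders.
def align_up(offset, alignment):
    return (offset + alignment - 1) // alignment * alignment


def layout_offsets(regions):
    """regions: list of (name, size_in_bytes, alignment) in SMEM order."""
    offsets = {}
    cursor = 0
    for name, size, alignment in regions:
        cursor = align_up(cursor, alignment)
        offsets[name] = cursor
        cursor += size
    return offsets, cursor  # cursor is the total footprint in bytes


regions = [
    ("tensormaps", 5 * 128, 128),     # from the docs
    ("a_tiles", 32768, 1024),         # hypothetical
    ("b_tiles", 32768, 1024),         # hypothetical
    ("c_tiles", 16384, 1024),         # hypothetical
    ("sfa_tiles", 1024, 128),         # hypothetical
    ("sfb_tiles", 1024, 128),         # hypothetical
    ("pipeline_barriers", 64, 8),     # hypothetical
    ("clc_barriers", 32, 8),          # hypothetical
    ("tmem_state", 8, 8),             # hypothetical
]
offsets, total = layout_offsets(regions)
print(offsets["a_tiles"])  # 1024
print(total)  # 85096
```

Placing the 640-byte descriptor region first means the first tile region may need padding to reach its alignment. In this example, that padding is the 384-byte gap between byte 640 and byte 1024.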

Fields

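  • tiles (GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles): Tile storage for the A, B, C, SFA, and SFB tiles.
  • pipelines (GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Pipelines): Pipeline storage for the group, accumulator, and CLC pipelines.
  • tensormap_a (TMADescriptor): TMA descriptor slot for the A operand.
  • tensormap_b (TMADescriptor): TMA descriptor slot for the B operand.
  • tensormap_sfa (TMADescriptor): TMA descriptor slot for the A scale factors (SFA).
  • tensormap_sfb (TMADescriptor): TMA descriptor slot for the B scale factors (SFB).
  • tensormap_c (TMADescriptor): TMA descriptor slot for the C output.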

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

__del__is_trivial

comptime __del__is_trivial = True

a_smem_layout

comptime a_smem_layout = GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Layouts.a_smem_layout

ATileArray

comptime ATileArray = GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles.ATileArray

b_smem_layout

comptime b_smem_layout = GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Layouts.b_smem_layout

BK

comptime BK = config.block_tile_shape.__getitem__[3, DType.int64, Int](2)

BM

comptime BM = config.block_tile_shape.__getitem__[3, DType.int64, Int](0)

BN

comptime BN = config.block_tile_shape.__getitem__[3, DType.int64, Int](1)
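BM, BN, and BK are the three components of `config.block_tile_shape` at indices 0, 1, and 2. The Python sketch below shows the mapping with a hypothetical shape. It also lists the per-stage A and B tile extents that the `Tiles` parameter order (BM, BK, BN, BK) suggests. That tile-shape reading is an inference from the parameter list, not a documented guarantee.

```python
# Sketch: how BM, BN, BK index into block_tile_shape.
# The shape below is a hypothetical example, not a library default.
block_tile_shape = (128, 256, 128)  # (M, N, K)

BM = block_tile_shape[0]
BN = block_tile_shape[1]
BK = block_tile_shape[2]

# Inferred from the Tiles parameter order (BM, BK, BN, BK):
a_tile_shape = (BM, BK)  # one A tile per pipeline stage
b_tile_shape = (BN, BK)  # one B tile per pipeline stage

print(BM, BN, BK)  # 128 256 128
print(a_tile_shape, b_tile_shape)  # (128, 128) (256, 128)
```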

BTileArray

comptime BTileArray = GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles.BTileArray

c_smem_layout

comptime c_smem_layout = GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Layouts.c_smem_layout

CTileArray

comptime CTileArray = GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles.CTileArray

Layouts

comptime Layouts = SmemLayouts[a_type, b_type, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BM, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BN, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputM, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputN, config.a_swizzle, config.b_swizzle, transpose_b]

MMA_M

comptime MMA_M = config.mma_shape.__getitem__[3, DType.int64, Int](0)

MMA_N

comptime MMA_N = config.mma_shape.__getitem__[3, DType.int64, Int](1)

num_accum_pipeline_stages

comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages

num_clc_pipeline_stages

comptime num_clc_pipeline_stages = config.num_clc_pipeline_stages

num_group_pipeline_stages

comptime num_group_pipeline_stages = (GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_pipeline_stages // config)

num_output_stages

comptime num_output_stages = config.num_output_stages

num_pipeline_stages

comptime num_pipeline_stages = config.num_pipeline_stages

OutputM

comptime OutputM = config.output_tile_shape.__getitem__[2, DType.int64, Int](0)

OutputN

comptime OutputN = config.output_tile_shape.__getitem__[2, DType.int64, Int](1)

Pipelines

comptime Pipelines = SmemPipelineBundle[GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_group_pipeline_stages, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_accum_pipeline_stages, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_clc_pipeline_stages, BlockScaledTilePayload[a_type, b_type, sfa_dtype, sfb_dtype, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BM, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BN, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFA_DIM0, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFA_DIM1, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFB_DIM0, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFB_DIM1, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_pipeline_stages]]
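`Pipelines` bundles three stage counts: group, accumulator, and CLC. Multi-stage SMEM pipelines built on barriers commonly use a ring-buffer indexing scheme, sketched below in Python. Iteration `k` uses slot `k % num_stages`, and the expected barrier phase bit flips each time the ring wraps around. This is the general technique only; `stage_and_phase` is a hypothetical helper and not part of `SmemPipelineBundle`'s API.

```python
# Generic ring-buffer pipeline indexing (hypothetical helper, not library API).
def stage_and_phase(iteration, num_stages):
    stage = iteration % num_stages              # which SMEM slot to use
    phase = (iteration // num_stages) % 2       # flips on every wrap-around
    return stage, phase


for k in range(6):
    print(k, stage_and_phase(k, num_stages=4))
# 0 (0, 0) ... 3 (3, 0), 4 (0, 1), 5 (1, 1)
```

The phase bit lets a consumer tell a barrier completion from the current lap apart from a stale one left over from the previous lap through the same slot.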

SF_BK

comptime SF_BK = sf_bk[config]()

SF_K_GROUP_SIZE

comptime SF_K_GROUP_SIZE = sf_k_group_size[config]()

SFA_DIM0

comptime SFA_DIM0 = sfa_dim0[config]()

SFA_DIM1

comptime SFA_DIM1 = sfa_dim1[config]()

sfa_smem_layout

comptime sfa_smem_layout = tile_sf_layout_k_major[GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BM, (GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SF_K_GROUP_SIZE * config), config.vec_sf_size]()

SFATileArray

comptime SFATileArray = GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles.SFATileArray

SFB_DIM0

comptime SFB_DIM0 = sfb_dim0[config]()

SFB_DIM1

comptime SFB_DIM1 = sfb_dim1[config]()

sfb_smem_layout

comptime sfb_smem_layout = tile_sf_layout_k_major[GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].MMA_N, (GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SF_K_GROUP_SIZE * config), config.vec_sf_size]()

SFBTileArray

comptime SFBTileArray = GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles.SFBTileArray

Tiles

comptime Tiles = BlockScaledTileStorage[a_type, b_type, c_type, sfa_dtype, sfb_dtype, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BM, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BN, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputM, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputN, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFA_DIM0, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFA_DIM1, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFB_DIM0, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFB_DIM1, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_pipeline_stages, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_output_stages]

Methods

a_tiles

a_tiles(ref[AddressSpace._value._mlir_value] self) -> GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ATileArray

Get A tile array accessor (TileTensor-based).

Returns:

GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ATileArray

b_tiles

b_tiles(ref[AddressSpace._value._mlir_value] self) -> GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BTileArray

Get B tile array accessor (TileTensor-based).

Returns:

GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BTileArray

c_tiles

c_tiles(ref[AddressSpace._value._mlir_value] self) -> GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].CTileArray

Get C tile array accessor (TileTensor-based).

Returns:

GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].CTileArray

sfa_tiles

sfa_tiles(ref[AddressSpace._value._mlir_value] self) -> GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFATileArray

Get SFA tile array accessor (TileTensor-based).

Returns:

GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFATileArray

sfb_tiles

sfb_tiles(ref[AddressSpace._value._mlir_value] self) -> GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFBTileArray

Get SFB tile array accessor (TileTensor-based).

Returns:

GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFBTileArray
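The tile accessors above each return a per-operand tile array. The Python sketch below shows the basic indexing idea behind a stage-indexed tile array. One contiguous buffer holds `num_stages` tiles of `rows x cols` elements, and `tile(i)` returns a view starting at `i * rows * cols`. `TileArray` here is a hypothetical stand-in: the real TileTensor-based arrays may use swizzled layouts. The sketch also assumes 1-byte elements so that a `bytearray` can back it.

```python
# Hypothetical stage-indexed tile array over a flat buffer (1-byte elements).
class TileArray:
    def __init__(self, buffer, rows, cols, num_stages):
        if len(buffer) < rows * cols * num_stages:
            raise ValueError("buffer too small for all stages")
        self.buffer = buffer
        self.rows = rows
        self.cols = cols
        self.num_stages = num_stages

    def tile(self, stage):
        """Return a zero-copy view of the tile for the given pipeline stage."""
        if not 0 <= stage < self.num_stages:
            raise IndexError(stage)
        size = self.rows * self.cols
        start = stage * size
        return memoryview(self.buffer)[start:start + size]


buf = bytearray(4 * 8 * 3)
a_tiles = TileArray(buf, rows=4, cols=8, num_stages=3)
view = a_tiles.tile(2)
view[0] = 7          # writes through to the shared buffer
print(buf[64])       # 7 (stage 2 starts at 2 * 4 * 8 = 64)
```

Returning views rather than copies mirrors how SMEM tile accessors alias a single allocation: writing through a stage's view updates the shared storage directly.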

tensormap_storage_size

static tensormap_storage_size() -> Int

Size of tensormap storage in bytes (5 x 128 = 640 bytes).

Returns:

Int

total_tile_size

static total_tile_size() -> Int

Total tile storage size (A+B+SFA+SFB+C) in elements.

Returns:

Int
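The Python sketch below shows one plausible way the `total_tile_size()` element count could be assembled. It assumes each A, B, SFA, and SFB tile is stored once per pipeline stage and each C tile once per output stage, matching the `num_pipeline_stages` and `num_output_stages` parameters passed to `Tiles`. The dimensions are hypothetical, and `total_tile_elements` is an illustrative helper, not the library function. Because `a_type`, `b_type`, `c_type`, `sfa_dtype`, and `sfb_dtype` are separate parameters, this is a sum of element counts, not bytes. A byte size would need each dtype's width.

```python
# Sketch of a total tile element count (assumptions noted above; numbers hypothetical).
def total_tile_elements(BM, BN, BK, output_m, output_n,
                        sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1,
                        num_pipeline_stages, num_output_stages):
    a = BM * BK * num_pipeline_stages
    b = BN * BK * num_pipeline_stages
    sfa = sfa_dim0 * sfa_dim1 * num_pipeline_stages
    sfb = sfb_dim0 * sfb_dim1 * num_pipeline_stages
    c = output_m * output_n * num_output_stages
    return a + b + sfa + sfb + c


print(total_tile_elements(BM=128, BN=128, BK=128, output_m=128, output_n=32,
                          sfa_dim0=128, sfa_dim1=4, sfb_dim0=128, sfb_dim1=4,
                          num_pipeline_stages=4, num_output_stages=2))  # 143360
```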
