Skip to main content

Mojo struct

Grouped1D1DSmem

struct Grouped1D1DSmem[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, transpose_b: Bool, *, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b]]

Shared-memory (SMEM) storage struct for grouped 1D-1D block-scaled GEMM.

Simplified version of GroupedBlockScaledSmem for offset-based addressing. It uses three-role warp specialization (Load, MMA, Epilogue) without a scheduler warp, so no CLC pipeline storage is needed.

Layout in SMEM:

  1. A tiles (input pipeline stages)
  2. B tiles (input pipeline stages)
  3. C tiles (output stages)
  4. SFA tiles (scaling factors for A)
  5. SFB tiles (scaling factors for B)
  6. Input pipeline barriers
  7. Output pipeline barriers (accum barriers)
  8. TMEM deallocation state

Fields

  • tiles (Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles): Tile storage (a BlockScaledTileStorage) for the A, B, C, SFA, and SFB tiles.
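  • pipelines (Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Pipelines): Pipeline storage (an SmemPipelineBundleNoClc) for the input and output (accumulator) pipeline barriers, without CLC pipeline storage.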

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

__del__is_trivial

comptime __del__is_trivial = True

a_smem_layout

comptime a_smem_layout = Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Layouts.a_smem_layout

ATileArray

comptime ATileArray = Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles.ATileArray

b_smem_layout

comptime b_smem_layout = Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Layouts.b_smem_layout

BK

comptime BK = config.block_tile_shape.__getitem__[Int](2)

BM

comptime BM = config.block_tile_shape.__getitem__[Int](0)

BN

comptime BN = config.block_tile_shape.__getitem__[Int](1)

BTileArray

comptime BTileArray = Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles.BTileArray

c_smem_layout

comptime c_smem_layout = Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Layouts.c_smem_layout

CTileArray

comptime CTileArray = Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles.CTileArray

Layouts

comptime Layouts = SmemLayouts[a_type, b_type, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BM, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BN, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputM, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputN, config.a_swizzle, config.b_swizzle, transpose_b]

MMA_M

comptime MMA_M = config.mma_shape.__getitem__[Int](0)

MMA_N

comptime MMA_N = config.mma_shape.__getitem__[Int](1)

num_accum_pipeline_stages

comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages

num_group_pipeline_stages

comptime num_group_pipeline_stages = (Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_pipeline_stages // config)

num_output_stages

comptime num_output_stages = config.num_output_stages

num_pipeline_stages

comptime num_pipeline_stages = config.num_pipeline_stages

OutputM

comptime OutputM = config.output_tile_shape.__getitem__[Int](0)

OutputN

comptime OutputN = config.output_tile_shape.__getitem__[Int](1)

Pipelines

comptime Pipelines = SmemPipelineBundleNoClc[Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_group_pipeline_stages, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_accum_pipeline_stages, BlockScaledTilePayload[a_type, b_type, sfa_dtype, sfb_dtype, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BM, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BN, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFA_DIM0, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFA_DIM1, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFB_DIM0, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFB_DIM1, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_pipeline_stages]]

SF_BK

comptime SF_BK = sf_bk[config]()

SF_K_GROUP_SIZE

comptime SF_K_GROUP_SIZE = sf_k_group_size[config]()

SFA_DIM0

comptime SFA_DIM0 = sfa_dim0[config]()

SFA_DIM1

comptime SFA_DIM1 = sfa_dim1[config]()

sfa_smem_layout

comptime sfa_smem_layout = tile_sf_layout_k_major[Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BM, (Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SF_K_GROUP_SIZE * config), config.vec_sf_size]()

SFATileArray

comptime SFATileArray = Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles.SFATileArray

SFB_DIM0

comptime SFB_DIM0 = sfb_dim0[config]()

SFB_DIM1

comptime SFB_DIM1 = sfb_dim1[config]()

sfb_smem_layout

comptime sfb_smem_layout = tile_sf_layout_k_major[Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].MMA_N, (Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SF_K_GROUP_SIZE * config), config.vec_sf_size]()

SFBTileArray

comptime SFBTileArray = Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Tiles.SFBTileArray

Tiles

comptime Tiles = BlockScaledTileStorage[a_type, b_type, c_type, sfa_dtype, sfb_dtype, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BM, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BN, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputM, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputN, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFA_DIM0, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFA_DIM1, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFB_DIM0, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFB_DIM1, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_pipeline_stages, Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_output_stages]

Methods

a_tiles

a_tiles(ref[AddressSpace._value._mlir_value] self) -> Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ATileArray

Get A tile array accessor.

Returns:

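Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ATileArray: Accessor for the A tile array.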

b_tiles

b_tiles(ref[AddressSpace._value._mlir_value] self) -> Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BTileArray

Get B tile array accessor.

Returns:

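Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BTileArray: Accessor for the B tile array.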

c_tiles

c_tiles(ref[AddressSpace._value._mlir_value] self) -> Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].CTileArray

Get C tile array accessor.

Returns:

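Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].CTileArray: Accessor for the C tile array.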

sfa_tiles

sfa_tiles(ref[AddressSpace._value._mlir_value] self) -> Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFATileArray

Get SFA tile array accessor.

Returns:

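Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFATileArray: Accessor for the SFA (A scaling-factor) tile array.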

sfb_tiles

sfb_tiles(ref[AddressSpace._value._mlir_value] self) -> Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFBTileArray

Get SFB tile array accessor.

Returns:

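Grouped1D1DSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFBTileArray: Accessor for the SFB (B scaling-factor) tile array.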

ab_pipeline_size

static ab_pipeline_size() -> Int

Total size of A+B tiles for all pipeline stages (in elements).

Returns:

Int

sf_pipeline_size

static sf_pipeline_size() -> Int

Total size of SFA+SFB tiles for all pipeline stages (in elements).

Returns:

Int

c_output_size

static c_output_size() -> Int

Size of C tiles for all output stages (in elements).

Returns:

Int

total_tile_size

static total_tile_size() -> Int

Total tile storage size (A+B+SFA+SFB+C) in elements.

Returns:

Int

Was this page helpful?