Skip to main content

Mojo struct

BlockScaledSmem

struct BlockScaledSmem[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, transpose_b: Bool, *, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b]]

Shared-memory (SMEM) struct holding the A/B tiles, the SFA/SFB scaling-factor tiles, the C output tiles, the barriers, and the CLC response and TMEM address storage.

Fields

  • a_tiles_storage (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ATileArray.Storage):
  • b_tiles_storage (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BTileArray.Storage):
  • c_tiles_storage (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].CTileArray.Storage):
  • sfa_tiles_storage (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFATileArray.Storage):
  • sfb_tiles_storage (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFBTileArray.Storage):
  • tma_mma_mbars_storage (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].InputBarriers.Storage):
  • accum_mbars_storage (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].AccumBarriers.Storage):
  • clc_mbars_full_storage (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ClcBarriers.Storage):
  • clc_mbars_empty_storage (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ClcBarriers.Storage):
  • clc_throttle_mbars_storage (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ClcThrottleBarriers.Storage):
  • clc_response_storage (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ClcResponse.Storage):
  • tmem_dealloc_mbar_storage (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].TmemDealloc.Storage):
  • tmem_addr_storage (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].TmemAddr.Storage):

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

__del__is_trivial

comptime __del__is_trivial = True

a_smem_layout

comptime a_smem_layout = tile_layout_k_major[a_type, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BM, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, config.a_swizzle]()

AccumBarriers

comptime AccumBarriers = SMemArray[SharedMemBarrier, (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_accum_pipeline_stages * 2)]

ATileArray

comptime ATileArray = SMemTileArray[a_type, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].a_smem_layout, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_pipeline_stages, 128]

b_smem_layout

comptime b_smem_layout = tile_layout_k_major[b_type, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BN, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, config.b_swizzle]() if transpose_b else tile_layout_mn_major[b_type, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BN, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, config.b_swizzle]()

BK

comptime BK = config.block_tile_shape.__getitem__[3, DType.int64, Int](2)

BM

comptime BM = config.block_tile_shape.__getitem__[3, DType.int64, Int](0)

BN

comptime BN = config.block_tile_shape.__getitem__[3, DType.int64, Int](1)

BTileArray

comptime BTileArray = SMemTileArray[b_type, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].b_smem_layout, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_pipeline_stages, 128]

c_smem_layout

comptime c_smem_layout = Layout.row_major(BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputM, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputN)

ClcBarriers

comptime ClcBarriers = SMemArray[SharedMemBarrier, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_clc_pipeline_stages]

ClcResponse

comptime ClcResponse = SMemArray[UInt128, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_clc_pipeline_stages]

ClcThrottleBarriers

comptime ClcThrottleBarriers = SMemArray[SharedMemBarrier, (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_clc_pipeline_stages * 2)]

CTileArray

comptime CTileArray = SMemTileArray[c_type, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].c_smem_layout, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_output_stages, 128]

InputBarriers

comptime InputBarriers = SMemArray[SharedMemBarrier, (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_group_pipeline_stages * 2)]

MMA_M

comptime MMA_M = config.mma_shape.__getitem__[3, DType.int64, Int](0)

MMA_N

comptime MMA_N = config.mma_shape.__getitem__[3, DType.int64, Int](1)

num_accum_pipeline_stages

comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages

num_clc_pipeline_stages

comptime num_clc_pipeline_stages = config.num_clc_pipeline_stages

num_group_pipeline_stages

comptime num_group_pipeline_stages = (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_pipeline_stages // config)

num_output_stages

comptime num_output_stages = config.num_output_stages

num_pipeline_stages

comptime num_pipeline_stages = config.num_pipeline_stages

OutputM

comptime OutputM = config.output_tile_shape.__getitem__[2, DType.int64, Int](0)

OutputN

comptime OutputN = config.output_tile_shape.__getitem__[2, DType.int64, Int](1)

SF_K_GROUP_SIZE

comptime SF_K_GROUP_SIZE = (4 * config)

sfa_smem_layout

comptime sfa_smem_layout = tile_sf_layout_k_major[BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BM, (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SF_K_GROUP_SIZE * config), config.vec_sf_size]()

SFATileArray

comptime SFATileArray = SMemTileArray[sfa_dtype, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].sfa_smem_layout, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_pipeline_stages, 128]

sfb_smem_layout

comptime sfb_smem_layout = tile_sf_layout_k_major[BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].MMA_N, (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SF_K_GROUP_SIZE * config), config.vec_sf_size]()

SFBTileArray

comptime SFBTileArray = SMemTileArray[sfb_dtype, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].sfb_smem_layout, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_pipeline_stages, 128]

TmemAddr

comptime TmemAddr = SMemArray[UInt32, 1]

TmemDealloc

comptime TmemDealloc = SMemArray[SharedMemBarrier, 1]

Methods

a_tiles

a_tiles(ref [3] self) -> BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ATileArray

Returns:

BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ATileArray

b_tiles

b_tiles(ref [3] self) -> BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BTileArray

Returns:

BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BTileArray

c_tiles

c_tiles(ref [3] self) -> BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].CTileArray

Returns:

BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].CTileArray

sfa_tiles

sfa_tiles(ref [3] self) -> BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFATileArray

Returns:

BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFATileArray

sfb_tiles

sfb_tiles(ref [3] self) -> BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFBTileArray

Returns:

BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFBTileArray

tma_mma_mbars

tma_mma_mbars(ref [3] self) -> BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].InputBarriers

Returns:

BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].InputBarriers

accum_mbars

accum_mbars(ref [3] self) -> BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].AccumBarriers

Returns:

BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].AccumBarriers

clc_mbars_full

clc_mbars_full(ref [3] self) -> BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ClcBarriers

Returns:

BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ClcBarriers

clc_mbars_empty

clc_mbars_empty(ref [3] self) -> BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ClcBarriers

Returns:

BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ClcBarriers

clc_throttle_mbars

clc_throttle_mbars(ref [3] self) -> BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ClcThrottleBarriers

Returns:

BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ClcThrottleBarriers

clc_response

clc_response(ref [3] self) -> BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ClcResponse

Returns:

BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ClcResponse

tmem_dealloc_mbar

tmem_dealloc_mbar(ref [3] self) -> BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].TmemDealloc

Returns:

BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].TmemDealloc

tmem_addr

tmem_addr(ref [3] self) -> BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].TmemAddr

Returns:

BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].TmemAddr

input_barriers

input_barriers(ref [3] self) -> BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].InputBarriers

Alias for tma_mma_mbars() to match the standard Smem API.

Returns:

BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].InputBarriers

accum_barriers

accum_barriers(ref [3] self) -> BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].AccumBarriers

Alias for accum_mbars() to match the standard Smem API.

Returns:

BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].AccumBarriers

tmem_dealloc

tmem_dealloc(ref [3] self) -> BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].TmemDealloc

Alias for tmem_dealloc_mbar() to match the standard Smem API.

Returns:

BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].TmemDealloc

ab_pipeline_size

static ab_pipeline_size() -> Int

Total size of A+B tiles for all pipeline stages (in elements).

Returns:

Int

sf_pipeline_size

static sf_pipeline_size() -> Int

Total size of SFA+SFB tiles for all pipeline stages (in elements).

Returns:

Int

c_output_size

static c_output_size() -> Int

Size of C tiles for all output stages (in elements).

Returns:

Int

total_tile_size

static total_tile_size() -> Int

Total tile storage size (A+B+SFA+SFB+C) in elements.

Returns:

Int

Was this page helpful?