Mojo struct
BlockScaledSmem
struct BlockScaledSmem[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, transpose_b: Bool, *, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b]]
Shared-memory (SMEM) struct containing the A/B tiles, scaling-factor tiles, C output tiles, and barriers.
Fields
- a_tiles_storage (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ATileArray.StorageType)
- b_tiles_storage (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BTileArray.StorageType)
- c_tiles_storage (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].CTileArray.StorageType)
- sfa_tiles_storage (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFATileArray.StorageType)
- sfb_tiles_storage (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFBTileArray.StorageType)
- tma_mma_mbars_storage (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].InputBarriers.StorageType)
- accum_mbars_storage (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].AccumBarriers.StorageType)
- clc_mbars_full_storage (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ClcBarriers.StorageType)
- clc_mbars_empty_storage (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ClcBarriers.StorageType)
- clc_throttle_mbars_storage (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ClcThrottleBarriers.StorageType)
- clc_response_storage (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ClcResponse.StorageType)
- tmem_dealloc_mbar_storage (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].TmemDealloc.StorageType)
- tmem_addr_storage (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].TmemAddr.StorageType)
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
__del__is_trivial
comptime __del__is_trivial = True
a_smem_layout
comptime a_smem_layout = tile_layout_k_major[a_type, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BM, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, config.a_swizzle]()
AccumBarriers
comptime AccumBarriers = SMemArrayType[SharedMemBarrier, (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_accum_pipeline_stages * 2)]
ATileArray
comptime ATileArray = SMemTileArrayType[a_type, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].a_smem_layout, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_pipeline_stages, 128]
b_smem_layout
comptime b_smem_layout = tile_layout_k_major[b_type, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BN, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, config.b_swizzle]() if transpose_b else tile_layout_mn_major[b_type, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BN, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, config.b_swizzle]()
BK
comptime BK = config.block_tile_shape.__getitem__[3, DType.int64, Int](2)
BM
comptime BM = config.block_tile_shape.__getitem__[3, DType.int64, Int](0)
BN
comptime BN = config.block_tile_shape.__getitem__[3, DType.int64, Int](1)
BTileArray
comptime BTileArray = SMemTileArrayType[b_type, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].b_smem_layout, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_pipeline_stages, 128]
c_smem_layout
comptime c_smem_layout = Layout.row_major(BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputM, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputN)
ClcBarriers
comptime ClcBarriers = SMemArrayType[SharedMemBarrier, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_clc_pipeline_stages]
ClcResponse
comptime ClcResponse = SMemArrayType[UInt128, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_clc_pipeline_stages]
ClcThrottleBarriers
comptime ClcThrottleBarriers = SMemArrayType[SharedMemBarrier, (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_clc_pipeline_stages * 2)]
CTileArray
comptime CTileArray = SMemTileArrayType[c_type, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].c_smem_layout, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_output_stages, 128]
InputBarriers
comptime InputBarriers = SMemArrayType[SharedMemBarrier, (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_group_pipeline_stages * 2)]
MMA_M
comptime MMA_M = config.mma_shape.__getitem__[3, DType.int64, Int](0)
MMA_N
comptime MMA_N = config.mma_shape.__getitem__[3, DType.int64, Int](1)
num_accum_pipeline_stages
comptime num_accum_pipeline_stages = Int(config)
num_clc_pipeline_stages
comptime num_clc_pipeline_stages = Int(config)
num_group_pipeline_stages
comptime num_group_pipeline_stages = (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_pipeline_stages // Int(config))
num_output_stages
comptime num_output_stages = Int(config)
num_pipeline_stages
comptime num_pipeline_stages = Int(config)
OutputM
comptime OutputM = config.output_tile_shape.__getitem__[2, DType.int64, Int](0)
OutputN
comptime OutputN = config.output_tile_shape.__getitem__[2, DType.int64, Int](1)
sfa_smem_layout
comptime sfa_smem_layout = tile_sf_layout_k_major[BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BM, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, 32]()
SFATileArray
comptime SFATileArray = SMemTileArrayType[sfa_dtype, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].sfa_smem_layout, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_pipeline_stages, 128]
sfb_smem_layout
comptime sfb_smem_layout = tile_sf_layout_k_major[BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].MMA_N, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK, 32]()
SFBTileArray
comptime SFBTileArray = SMemTileArrayType[sfb_dtype, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].sfb_smem_layout, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_pipeline_stages, 128]
TmemAddr
comptime TmemAddr = SMemArrayType[UInt32, 1]
TmemDealloc
comptime TmemDealloc = SMemArrayType[SharedMemBarrier, 1]
Methods
a_tiles
a_tiles(ref [3] self) -> BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ATileArray
Returns:
BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ATileArray
b_tiles
b_tiles(ref [3] self) -> BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BTileArray
Returns:
BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BTileArray
c_tiles
c_tiles(ref [3] self) -> BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].CTileArray
Returns:
BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].CTileArray
sfa_tiles
sfa_tiles(ref [3] self) -> BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFATileArray
Returns:
BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFATileArray
sfb_tiles
sfb_tiles(ref [3] self) -> BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFBTileArray
Returns:
BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].SFBTileArray
tma_mma_mbars
tma_mma_mbars(ref [3] self) -> BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].InputBarriers
Returns:
BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].InputBarriers
accum_mbars
accum_mbars(ref [3] self) -> BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].AccumBarriers
Returns:
BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].AccumBarriers
clc_mbars_full
clc_mbars_full(ref [3] self) -> BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ClcBarriers
Returns:
BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ClcBarriers
clc_mbars_empty
clc_mbars_empty(ref [3] self) -> BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ClcBarriers
Returns:
BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ClcBarriers
clc_throttle_mbars
clc_throttle_mbars(ref [3] self) -> BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ClcThrottleBarriers
Returns:
BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ClcThrottleBarriers
clc_response
clc_response(ref [3] self) -> BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ClcResponse
Returns:
BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].ClcResponse
tmem_dealloc_mbar
tmem_dealloc_mbar(ref [3] self) -> BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].TmemDealloc
Returns:
BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].TmemDealloc
tmem_addr
tmem_addr(ref [3] self) -> BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].TmemAddr
Returns:
BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].TmemAddr
ab_pipeline_size
static ab_pipeline_size() -> Int
Total size of A+B tiles for all pipeline stages (in elements).
Returns:
Int
sf_pipeline_size
static sf_pipeline_size() -> Int
Total size of SFA+SFB tiles for all pipeline stages (in elements).
Returns:
Int
c_output_size
total_tile_size
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!