Mojo struct
BlockScaledTmem
struct BlockScaledTmem[accum_dtype: DType, MMA_M: Int, MMA_N: Int, num_accum_stages: Int, sf_dtype: DType, BM: Int, num_pipeline_stages: Int, *, cta_group: Int = 1, total_cols: Int = 512, num_sf_k_tiles: Int = 1, SFB_N: Int = MMA_N]
TMEM region for block-scaled matmul with typed tile accessors.
Manages the TMEM address space for block-scaled MMA operations, providing typed TmemTensor access to:
- Accumulator tiles (one per output pipeline stage)
- SFA scaling factor tiles (one per k-iteration)
- SFB scaling factor tiles (one per k-iteration)
Memory layout (512 columns total):
┌──────────────────┬──────────────────┬──────────────────┐
│ Accumulators     │ SFA Scales       │ SFB Scales       │
│ (stages × MMA_N) │ (iters × cols)   │ (iters × cols)   │
└──────────────────┴──────────────────┴──────────────────┘
Parameters
- accum_dtype (DType): Accumulator data type (typically float32).
- MMA_M (Int): MMA M dimension.
- MMA_N (Int): MMA N dimension (also stage stride for accumulators).
- num_accum_stages (Int): Number of accumulator pipeline stages.
- sf_dtype (DType): Scaling factor data type.
- BM (Int): Block M dimension (for SFA sizing).
- num_pipeline_stages (Int): Number of k-iteration pipeline stages.
- cta_group (Int): CTA group size (1 or 2).
- total_cols (Int): Total TMEM columns (512 for SM100).
- num_sf_k_tiles (Int): Scaling factor tiles per K-iteration. MXFP8 uses 1 (one SF vector per K-tile). NVFP4 uses 4 (multiple SF vectors per K-tile).
- SFB_N (Int): SFB N dimension for TMEM layout. Defaults to MMA_N. Set to align_up(MMA_N, SF_MN_GROUP_SIZE) when MMA_N < SF_MN_GROUP_SIZE so the TMEM tile is wide enough for the SMEM-to-TMEM copy (which always writes a full SF_MN_GROUP_SIZE group).
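The SFB_N rule above can be sketched in plain Python (not Mojo). This is only an illustration: align_up is written out as the usual round-up-to-a-multiple helper, choose_sfb_n is a hypothetical name, and the group size of 128 in the usage lines is an assumed value that this page does not state.

```python
def align_up(value: int, alignment: int) -> int:
    """Round value up to the next multiple of alignment."""
    if alignment <= 0:
        raise ValueError("alignment must be positive")
    return ((value + alignment - 1) // alignment) * alignment


def choose_sfb_n(mma_n: int, sf_mn_group_size: int) -> int:
    """Hypothetical helper: pick SFB_N following the documented rule.

    Widen to a full group only when MMA_N is narrower than one group;
    otherwise keep the default of MMA_N.
    """
    if mma_n < sf_mn_group_size:
        return align_up(mma_n, sf_mn_group_size)
    return mma_n


# Assumed group size, for illustration only.
ASSUMED_GROUP_SIZE = 128
narrow = choose_sfb_n(64, ASSUMED_GROUP_SIZE)   # narrower than one group
wide = choose_sfb_n(256, ASSUMED_GROUP_SIZE)    # already wide enough
```

With the assumed group size, a narrow MMA_N is widened to one full group, while a wider MMA_N is passed through unchanged.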
Fields
- base_addr (Int): Base TMEM address of the region.
Implemented traits
AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable
comptime members
accum_layout
comptime accum_layout = Layout.row_major(MMA_M, MMA_N)
accum_offset
comptime accum_offset = 0
AccumArray
comptime AccumArray = TmemArrayType[accum_dtype, BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles, SFB_N=SFB_N].accum_layout, num_accum_stages, cta_group=cta_group]
AccumTile
comptime AccumTile = BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles, SFB_N=SFB_N].AccumArray.Tile
sfa_layout
comptime sfa_layout = Layout.row_major(1, (num_sf_k_tiles * (BM // 32)))
sfa_offset
comptime sfa_offset = BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles, SFB_N=SFB_N].AccumArray.num_cols
SFAArray
comptime SFAArray = TmemArrayType[sf_dtype, BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles, SFB_N=SFB_N].sfa_layout, num_pipeline_stages, cta_group=cta_group]
SFATile
comptime SFATile = BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles, SFB_N=SFB_N].SFAArray.Tile
sfb_layout
comptime sfb_layout = Layout.row_major(1, (num_sf_k_tiles * (SFB_N // 32)))
sfb_offset
comptime sfb_offset = (BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles, SFB_N=SFB_N].sfa_offset + BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles, SFB_N=SFB_N].SFAArray.num_cols)
SFBArray
comptime SFBArray = TmemArrayType[sf_dtype, BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles, SFB_N=SFB_N].sfb_layout, num_pipeline_stages, cta_group=cta_group]
SFBTile
comptime SFBTile = BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles, SFB_N=SFB_N].SFBArray.Tile
used_cols
comptime used_cols = (BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles, SFB_N=SFB_N].sfb_offset + BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles, SFB_N=SFB_N].SFBArray.num_cols)
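The offsets above chain together: each region starts where the previous one ends. The following plain-Python model (not the Mojo API) works through that arithmetic. It assumes each array occupies its stage count times its layout's column count, which matches the "stages × MMA_N" and "iters × cols" labels in the memory layout but is not confirmed by this page for every dtype or cta_group. The function name tmem_column_plan and the capacity check against total_cols are additions for illustration.

```python
def tmem_column_plan(mma_n, num_accum_stages, bm, num_pipeline_stages,
                     num_sf_k_tiles=1, sfb_n=None, total_cols=512):
    """Model the documented offsets; returns a dict of column positions."""
    if sfb_n is None:
        sfb_n = mma_n  # documented default: SFB_N = MMA_N

    # Columns per scaling-factor tile, from sfa_layout and sfb_layout.
    sfa_tile_cols = num_sf_k_tiles * (bm // 32)
    sfb_tile_cols = num_sf_k_tiles * (sfb_n // 32)

    # Assumption: array width = number of stages * columns per tile.
    accum_offset = 0
    sfa_offset = accum_offset + num_accum_stages * mma_n
    sfb_offset = sfa_offset + num_pipeline_stages * sfa_tile_cols
    used_cols = sfb_offset + num_pipeline_stages * sfb_tile_cols

    # Illustrative check only; the real struct may validate differently.
    if used_cols > total_cols:
        raise ValueError(f"plan needs {used_cols} columns, only {total_cols} available")

    return {
        "accum_offset": accum_offset,
        "sfa_offset": sfa_offset,
        "sfb_offset": sfb_offset,
        "used_cols": used_cols,
    }


# One SF tile per K-iteration (the MXFP8 case described in the parameters).
mxfp8_plan = tmem_column_plan(mma_n=128, num_accum_stages=2, bm=128,
                              num_pipeline_stages=4, num_sf_k_tiles=1)

# Four SF tiles per K-iteration (the NVFP4 case described in the parameters).
nvfp4_plan = tmem_column_plan(mma_n=128, num_accum_stages=2, bm=128,
                              num_pipeline_stages=4, num_sf_k_tiles=4)
```

Under these assumptions, the accumulators dominate the budget, and raising num_sf_k_tiles from 1 to 4 quadruples the columns used by both scaling factor regions.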
Methods
__init__
__init__(base_addr: Int) -> Self
Create TMEM region view at the given base address.
__init__(addr: TmemAddress) -> Self
Create TMEM region view from a TmemAddress.
__init__[cta: Int, max_cols: Int](alloc: TmemAllocation[cta, max_cols]) -> Self
Create TMEM region view from a TmemAllocation.
accum_tiles
accum_tiles(self) -> BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles, SFB_N=SFB_N].AccumArray
Get array of accumulator tiles.
Returns: AccumArray, the array of num_accum_stages accumulator tiles.
sfa_tiles
sfa_tiles(self) -> BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles, SFB_N=SFB_N].SFAArray
Get array of SFA scaling factor tiles.
Returns: SFAArray, the array of num_pipeline_stages SFA scaling factor tiles.
sfb_tiles
sfb_tiles(self) -> BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles, SFB_N=SFB_N].SFBArray
Get array of SFB scaling factor tiles.
Returns: SFBArray, the array of num_pipeline_stages SFB scaling factor tiles.
accum
accum[T: Intable](self, stage: T) -> BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles, SFB_N=SFB_N].AccumTile
Get accumulator tile for the given pipeline stage.
Returns: AccumTile, the accumulator tile for the given stage.
sfa
sfa[T: Intable](self, index: T) -> BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles, SFB_N=SFB_N].SFATile
Get SFA scaling factor tile for the given k-iteration index.
Returns: SFATile, the SFA scaling factor tile at the given index.
sfb
sfb[T: Intable](self, index: T) -> BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles, SFB_N=SFB_N].SFBTile
Get SFB scaling factor tile for the given k-iteration index.
Returns: SFBTile, the SFB scaling factor tile at the given index.
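To make the per-stage accessors concrete, here is a plain-Python sketch (not the Mojo API) of how a stage or k-iteration index could map to a starting column. The accumulator stride of MMA_N comes from the parameter description. The scaling factor stride (one tile width per index) and the treatment of base_addr as a bare column number are assumptions. Both function names are hypothetical.

```python
def accum_start_column(base_col: int, stage: int, mma_n: int,
                       num_accum_stages: int) -> int:
    """Start column of the accumulator tile for a pipeline stage.

    The stage stride is MMA_N, as stated in the parameter list.
    """
    if not 0 <= stage < num_accum_stages:
        raise IndexError(f"stage {stage} out of range")
    return base_col + stage * mma_n


def sf_start_column(base_col: int, region_offset: int, index: int,
                    tile_cols: int, num_pipeline_stages: int) -> int:
    """Start column of a scaling factor tile for a k-iteration index.

    Assumption: consecutive tiles are packed tile_cols columns apart,
    starting at the region's offset (sfa_offset or sfb_offset).
    """
    if not 0 <= index < num_pipeline_stages:
        raise IndexError(f"index {index} out of range")
    return base_col + region_offset + index * tile_cols


# Example values: MMA_N = 128, two accumulator stages, BM = 128,
# four pipeline stages, one SF tile per K-iteration.
second_accum = accum_start_column(base_col=0, stage=1, mma_n=128,
                                  num_accum_stages=2)
third_sfa = sf_start_column(base_col=0, region_offset=256, index=2,
                            tile_cols=128 // 32, num_pipeline_stages=4)
```

In the struct itself these lookups return typed tiles rather than raw column numbers; the sketch only shows the address arithmetic such accessors would plausibly perform.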