Mojo struct
BlockScaledTmem
@register_passable(trivial)
struct BlockScaledTmem[accum_dtype: DType, MMA_M: Int, MMA_N: Int, num_accum_stages: Int, sf_dtype: DType, BM: Int, num_pipeline_stages: Int, *, cta_group: Int = 1, total_cols: Int = 512]
TMEM region for block-scaled matmul with typed tile accessors.
Manages the TMEM address space for block-scaled MMA operations, providing typed TmemTensor access to:
- Accumulator tiles (one per output pipeline stage)
- SFA scaling factor tiles (one per k-iteration)
- SFB scaling factor tiles (one per k-iteration)
Memory layout (512 columns total):

┌──────────────────┬────────────────┬────────────────┐
│ Accumulators     │ SFA Scales     │ SFB Scales     │
│ (stages × MMA_N) │ (iters × cols) │ (iters × cols) │
└──────────────────┴────────────────┴────────────────┘
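The column arithmetic behind this layout can be sketched in plain Python. This is an illustrative model, not the Mojo API: the function name `tmem_column_map` is hypothetical, and it assumes each tile occupies as many TMEM columns as the inner dimension of its layout (MMA_N per accumulator stage, BM // 32 per SFA tile, MMA_N // 32 per SFB tile). The real per-tile column counts come from TmemArrayType and may differ.

```python
def tmem_column_map(mma_n, num_accum_stages, bm, num_pipeline_stages):
    """Model the region offsets, in TMEM columns, of the block-scaled layout.

    Assumption (not confirmed by the docs): a tile's column count equals the
    inner dimension of its row-major layout.
    """
    accum_cols = num_accum_stages * mma_n            # one accumulator tile per output stage
    sfa_cols = num_pipeline_stages * (bm // 32)      # one SFA tile per k-iteration stage
    sfb_cols = num_pipeline_stages * (mma_n // 32)   # one SFB tile per k-iteration stage

    accum_offset = 0                                 # accumulators start the region
    sfa_offset = accum_offset + accum_cols           # SFA follows the accumulators
    sfb_offset = sfa_offset + sfa_cols               # SFB follows SFA
    used_cols = sfb_offset + sfb_cols                # total columns consumed

    return {
        "accum_offset": accum_offset,
        "sfa_offset": sfa_offset,
        "sfb_offset": sfb_offset,
        "used_cols": used_cols,
    }


layout = tmem_column_map(mma_n=128, num_accum_stages=2, bm=128, num_pipeline_stages=4)
print(layout)
```

Under these assumptions, the example configuration places SFA at column 256 and SFB at column 272, and uses 288 of the 512 columns.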
Parameters
- accum_dtype (DType): Accumulator data type (typically float32).
- MMA_M (Int): MMA M dimension.
- MMA_N (Int): MMA N dimension (also stage stride for accumulators).
- num_accum_stages (Int): Number of accumulator pipeline stages.
- sf_dtype (DType): Scaling factor data type.
- BM (Int): Block M dimension (for SFA sizing).
- num_pipeline_stages (Int): Number of k-iteration pipeline stages.
- cta_group (Int): CTA group size (1 or 2).
- total_cols (Int): Total TMEM columns (512 for SM100).
Fields
- base_addr (Int): Base TMEM address of the region.
Implemented traits
AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
accum_layout
comptime accum_layout = Layout.row_major(MMA_M, MMA_N)
accum_offset
comptime accum_offset = 0
AccumArray
comptime AccumArray = TmemArrayType[accum_dtype, BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols].accum_layout, num_accum_stages, cta_group=cta_group]
AccumTile
comptime AccumTile = BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols].AccumArray.Tile
sfa_layout
comptime sfa_layout = Layout.row_major(1, (BM // 32))
sfa_offset
comptime sfa_offset = BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols].AccumArray.num_cols
SFAArray
comptime SFAArray = TmemArrayType[sf_dtype, BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols].sfa_layout, num_pipeline_stages, cta_group=cta_group]
SFATile
comptime SFATile = BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols].SFAArray.Tile
sfb_layout
comptime sfb_layout = Layout.row_major(1, (MMA_N // 32))
sfb_offset
comptime sfb_offset = (BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols].sfa_offset + BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols].SFAArray.num_cols)
SFBArray
comptime SFBArray = TmemArrayType[sf_dtype, BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols].sfb_layout, num_pipeline_stages, cta_group=cta_group]
SFBTile
comptime SFBTile = BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols].SFBArray.Tile
used_cols
comptime used_cols = (BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols].sfb_offset + BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols].SFBArray.num_cols)
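Since used_cols is the sum of the three regions, a configuration can only fit if it does not exceed total_cols. Whether the struct itself enforces this is not stated here. The sketch below is a plain-Python estimate of how many accumulator stages fit once the scaling factor regions are reserved. The helper `max_accum_stages` is hypothetical, and it assumes each accumulator stage takes MMA_N columns while each SFA and SFB tile takes BM // 32 and MMA_N // 32 columns respectively.

```python
def max_accum_stages(mma_n, bm, num_pipeline_stages, total_cols=512):
    """Estimate the largest num_accum_stages that keeps used_cols <= total_cols.

    Assumption: per-tile column counts follow the layouts' inner dimensions.
    """
    # Columns reserved for SFA and SFB tiles across all k-iteration stages.
    sf_cols = num_pipeline_stages * (bm // 32 + mma_n // 32)
    # Remaining columns are split into accumulator stages of MMA_N columns each.
    return max((total_cols - sf_cols) // mma_n, 0)


stages_n128 = max_accum_stages(mma_n=128, bm=128, num_pipeline_stages=4)
stages_n256 = max_accum_stages(mma_n=256, bm=128, num_pipeline_stages=4)
print(stages_n128, stages_n256)
```

In this model, MMA_N = 256 leaves room for a single accumulator stage once the scaling factors are accounted for, while MMA_N = 128 leaves room for three.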
Methods
__init__
__init__(base_addr: Int) -> Self
Create TMEM region view at the given base address.
__init__(addr: TmemAddress) -> Self
Create TMEM region view from a TmemAddress.
__init__[cta: Int, max_cols: Int](alloc: TmemAllocation[cta, max_cols]) -> Self
Create TMEM region view from a TmemAllocation.
accum_tiles
accum_tiles(self) -> BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols].AccumArray
Get array of accumulator tiles.
Returns: AccumArray, the array of accumulator tiles.
sfa_tiles
sfa_tiles(self) -> BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols].SFAArray
Get array of SFA scaling factor tiles.
Returns: SFAArray, the array of SFA scaling factor tiles.
sfb_tiles
sfb_tiles(self) -> BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols].SFBArray
Get array of SFB scaling factor tiles.
Returns: SFBArray, the array of SFB scaling factor tiles.
accum
accum[T: Intable](self, stage: T) -> BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols].AccumTile
Get accumulator tile for the given pipeline stage.
Returns: AccumTile for the requested stage.
sfa
sfa[T: Intable](self, index: T) -> BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols].SFATile
Get SFA scaling factor tile for the given k-iteration index.
Returns: SFATile for the requested k-iteration index.
sfb
sfb[T: Intable](self, index: T) -> BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols].SFBTile
Get SFB scaling factor tile for the given k-iteration index.
Returns: SFBTile for the requested k-iteration index.
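How the indexed accessors might map a stage or k-iteration index to a column can be sketched in plain Python. This is a simplified model rather than the Mojo implementation: `tile_column` is a hypothetical helper, the base address is treated as a plain column index, and the tile strides for SFA and SFB (BM // 32 and MMA_N // 32) are assumptions. Only the accumulator stride of MMA_N is stated in the parameter descriptions.

```python
def tile_column(base_col, region_offset, tile_stride, index):
    """Column of tile `index` inside a region that starts at `region_offset`."""
    return base_col + region_offset + index * tile_stride


# Example configuration (values chosen for illustration only).
MMA_N, BM = 128, 128
NUM_ACCUM_STAGES, NUM_PIPELINE_STAGES = 2, 4

# Region offsets, following the accum -> SFA -> SFB ordering of the layout.
accum_offset = 0
sfa_offset = accum_offset + NUM_ACCUM_STAGES * MMA_N
sfb_offset = sfa_offset + NUM_PIPELINE_STAGES * (BM // 32)

accum_1 = tile_column(0, accum_offset, MMA_N, 1)      # models accum(1)
sfa_2 = tile_column(0, sfa_offset, BM // 32, 2)       # models sfa(2)
sfb_3 = tile_column(0, sfb_offset, MMA_N // 32, 3)    # models sfb(3)
print(accum_1, sfa_2, sfb_3)
```

Each accessor reduces to a base plus region offset plus index times stride, which is why the offsets can all be resolved at compile time as comptime members.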