Mojo struct
BlockScaledTmem
struct BlockScaledTmem[accum_dtype: DType, MMA_M: Int, MMA_N: Int, num_accum_stages: Int, sf_dtype: DType, BM: Int, num_pipeline_stages: Int, *, cta_group: Int = Int(1), total_cols: Int = Int(512), num_sf_k_tiles: Int = Int(1), SFB_N: Int = MMA_N]
TMEM region for block-scaled matmul with typed tile accessors.
Manages the TMEM address space for block-scaled MMA operations, providing typed TmemTensor access to:
- Accumulator tiles (one per output pipeline stage)
- SFA scaling factor tiles (one per k-iteration)
- SFB scaling factor tiles (one per k-iteration)
Memory layout (512 columns total):
┌──────────────────┬─────────────────┬─────────────────┐
│ Accumulators     │ SFA Scales      │ SFB Scales      │
│ (stages × MMA_N) │ (iters × cols)  │ (iters × cols)  │
└──────────────────┴─────────────────┴─────────────────┘
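The layout above can be read as three back-to-back column ranges. The following is a plain-Python sketch of that arithmetic, not Mojo and not part of the library; it mirrors the column counts that appear in the comptime members of this struct (accum_layout, sfa_layout, sfb_layout). The function name `tmem_regions` and the example parameter values are hypothetical.

```python
def tmem_regions(mma_n, num_accum_stages, bm, num_pipeline_stages,
                 num_sf_k_tiles=1, sfb_n=None):
    """Return {region: (first_col, end_col)}, with end_col exclusive."""
    if sfb_n is None:
        sfb_n = mma_n  # SFB_N defaults to MMA_N

    # One accumulator tile is MMA_N columns wide, one tile per accumulator stage.
    accum_cols = mma_n * num_accum_stages
    # One SFA tile is num_sf_k_tiles * (BM // 32) columns, one tile per pipeline stage.
    sfa_cols = num_sf_k_tiles * (bm // 32) * num_pipeline_stages
    # One SFB tile is num_sf_k_tiles * (SFB_N // 32) columns, one tile per pipeline stage.
    sfb_cols = num_sf_k_tiles * (sfb_n // 32) * num_pipeline_stages

    sfa_start = accum_cols
    sfb_start = sfa_start + sfa_cols
    return {
        "accum": (0, accum_cols),
        "sfa": (sfa_start, sfb_start),
        "sfb": (sfb_start, sfb_start + sfb_cols),
    }


# Hypothetical configuration, chosen only for illustration.
regions = tmem_regions(mma_n=128, num_accum_stages=2, bm=128,
                       num_pipeline_stages=4)
print(regions)
```

With these example values the accumulators occupy columns 0 to 255, and each scaling-factor region takes 16 columns after that.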
Parameters
- accum_dtype (DType): Accumulator data type (typically float32).
- MMA_M (Int): MMA M dimension.
- MMA_N (Int): MMA N dimension (also stage stride for accumulators).
- num_accum_stages (Int): Number of accumulator pipeline stages.
- sf_dtype (DType): Scaling factor data type.
- BM (Int): Block M dimension (for SFA sizing).
- num_pipeline_stages (Int): Number of k-iteration pipeline stages.
- cta_group (Int): CTA group size (1 or 2).
- total_cols (Int): Total TMEM columns (512 for SM100).
- num_sf_k_tiles (Int): Scaling factor tiles per K-iteration. MXFP8 uses 1 (one SF vector per K-tile). NVFP4 uses 4 (multiple SF vectors per K-tile).
- SFB_N (Int): SFB N dimension for TMEM layout. Defaults to MMA_N. Set to align_up(MMA_N, SF_MN_GROUP_SIZE) when MMA_N < SF_MN_GROUP_SIZE so the TMEM tile is wide enough for the SMEM-to-TMEM copy (which always writes a full SF_MN_GROUP_SIZE group).
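The SFB_N rule can be illustrated with a small plain-Python sketch. The value of SF_MN_GROUP_SIZE is not stated on this page, so the 128 used here is an assumption for illustration only, and `choose_sfb_n` is a hypothetical helper rather than a library function.

```python
def align_up(value, alignment):
    """Round value up to the next multiple of alignment."""
    return ((value + alignment - 1) // alignment) * alignment


# Assumed group size, for illustration only; the real constant is not given here.
SF_MN_GROUP_SIZE = 128


def choose_sfb_n(mma_n, group=SF_MN_GROUP_SIZE):
    """Pick SFB_N: pad narrow MMA_N up to a full group, otherwise use MMA_N."""
    return align_up(mma_n, group) if mma_n < group else mma_n


narrow = choose_sfb_n(64)    # MMA_N below the group size is padded up
wide = choose_sfb_n(256)     # MMA_N at or above the group size is unchanged

# Columns in one SFB tile with num_sf_k_tiles = 1, following sfb_layout.
cols_unpadded = 1 * (64 // 32)
cols_padded = 1 * (narrow // 32)
print(narrow, wide, cols_unpadded, cols_padded)
```

Under the assumed group size, padding MMA_N = 64 up to 128 doubles the SFB tile from 2 to 4 columns, which is the extra width the full-group copy needs.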
Fields
- base_addr (Int)
Implemented traits
AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable
comptime members
accum_layout
comptime accum_layout = Layout.row_major(MMA_M, MMA_N)
accum_offset
comptime accum_offset = 0
AccumArray
comptime AccumArray = TmemArrayType[accum_dtype, BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles, SFB_N=SFB_N].accum_layout, num_accum_stages, cta_group=cta_group]
AccumTile
comptime AccumTile = BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles, SFB_N=SFB_N].AccumArray.Tile
sfa_layout
comptime sfa_layout = Layout.row_major(Int(1), (num_sf_k_tiles * (BM // Int(32))))
sfa_offset
comptime sfa_offset = BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles, SFB_N=SFB_N].AccumArray.num_cols
SFAArray
comptime SFAArray = TmemArrayType[sf_dtype, BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles, SFB_N=SFB_N].sfa_layout, num_pipeline_stages, cta_group=cta_group]
SFATile
comptime SFATile = BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles, SFB_N=SFB_N].SFAArray.Tile
sfb_layout
comptime sfb_layout = Layout.row_major(Int(1), (num_sf_k_tiles * (SFB_N // Int(32))))
sfb_offset
comptime sfb_offset = (Int((mul Layout.row_major(MMA_M, MMA_N).shape[1].value(), num_accum_stages)) + Int((mul Layout.row_major(Int(1), Int((mul (BM // Int(32)), num_sf_k_tiles))).shape[1].value(), num_pipeline_stages)))
SFBArray
comptime SFBArray = TmemArrayType[sf_dtype, BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles, SFB_N=SFB_N].sfb_layout, num_pipeline_stages, cta_group=cta_group]
SFBTile
comptime SFBTile = BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols, num_sf_k_tiles=num_sf_k_tiles, SFB_N=SFB_N].SFBArray.Tile
used_cols
comptime used_cols = (Int((add (mul Layout.row_major(MMA_M, MMA_N).shape[1].value(), num_accum_stages), (mul Layout.row_major(Int(1), Int((mul (BM // Int(32)), num_sf_k_tiles))).shape[1].value(), num_pipeline_stages))) + Int((mul Layout.row_major(Int(1), Int((mul (SFB_N // Int(32)), num_sf_k_tiles))).shape[1].value(), num_pipeline_stages)))
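The expanded used_cols expression reduces to the sum of three products: accumulator columns, SFA columns, and SFB columns. As a hedged illustration, the plain-Python sketch below (not Mojo, not a library API) evaluates that sum for two hypothetical configurations that differ only in num_sf_k_tiles and compares each against a 512-column budget. Whether the struct itself enforces this bound is not stated on this page.

```python
def used_cols(mma_n, num_accum_stages, bm, num_pipeline_stages,
              num_sf_k_tiles=1, sfb_n=None):
    """Total TMEM columns used, following the used_cols comptime expression."""
    sfb_n = mma_n if sfb_n is None else sfb_n
    accum = mma_n * num_accum_stages
    sfa = num_sf_k_tiles * (bm // 32) * num_pipeline_stages
    sfb = num_sf_k_tiles * (sfb_n // 32) * num_pipeline_stages
    return accum + sfa + sfb


TOTAL_COLS = 512  # total_cols default (512 for SM100)

# Hypothetical tile shapes; only num_sf_k_tiles differs between the two.
one_sf_tile = used_cols(256, 1, 128, 6, num_sf_k_tiles=1)    # MXFP8-style setting
four_sf_tiles = used_cols(256, 1, 128, 6, num_sf_k_tiles=4)  # NVFP4-style setting

print(one_sf_tile, one_sf_tile <= TOTAL_COLS)
print(four_sf_tiles, four_sf_tiles <= TOTAL_COLS)
```

In this example the configuration with one scaling-factor tile per K-iteration uses 328 columns, while the same shape with four tiles needs 544 and would exceed the budget, so a smaller num_pipeline_stages or tile shape would be required.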
Methods
__init__
def __init__(base_addr: Int) -> Self
Create TMEM region view at the given base address.
def __init__(addr: TmemAddress) -> Self
Create TMEM region view from a TmemAddress.
def __init__[cta: Int, max_cols: Int](alloc: TmemAllocation[cta, max_cols]) -> Self
Create TMEM region view from a TmemAllocation.
accum_tiles
def accum_tiles(self) -> Self.AccumArray
Get array of accumulator tiles.
Returns:
Self.AccumArray
sfa_tiles
def sfa_tiles(self) -> Self.SFAArray
Get array of SFA scaling factor tiles.
Returns:
Self.SFAArray
sfb_tiles
def sfb_tiles(self) -> Self.SFBArray
Get array of SFB scaling factor tiles.
Returns:
Self.SFBArray
accum
def accum[T: Intable](self, stage: T) -> Self.AccumTile
Get accumulator tile for the given pipeline stage.
Returns:
Self.AccumTile
sfa
def sfa[T: Intable](self, index: T) -> Self.SFATile
Get SFA scaling factor tile for the given k-iteration index.
Returns:
Self.SFATile
sfb
def sfb[T: Intable](self, index: T) -> Self.SFBTile
Get SFB scaling factor tile for the given k-iteration index.
Returns:
Self.SFBTile
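The indexed accessors accum, sfa, and sfb each select one tile from their region. A rough plain-Python model of the column each call would land on is sketched below. It assumes tiles within a region are packed contiguously with a stride equal to the tile's column count (the parameter list states this for accumulators, where MMA_N is the stage stride; for scaling-factor tiles it is an assumption), and it treats the base address as a plain column index. `tile_col` is a hypothetical helper, not a library function.

```python
def tile_col(region_offset, cols_per_tile, index, base_col=0):
    """First column of tile `index` in a region of contiguously packed tiles."""
    return base_col + region_offset + index * cols_per_tile


# Hypothetical configuration, for illustration only.
MMA_N, NUM_ACCUM_STAGES = 128, 2
BM, NUM_SF_K_TILES = 128, 1

accum_offset = 0                               # accum_offset
sfa_offset = MMA_N * NUM_ACCUM_STAGES          # columns taken by all accumulators
sfa_cols_per_tile = NUM_SF_K_TILES * (BM // 32)

accum_stage_1 = tile_col(accum_offset, MMA_N, 1)        # models accum(1)
sfa_index_3 = tile_col(sfa_offset, sfa_cols_per_tile, 3)  # models sfa(3)
print(accum_stage_1, sfa_index_3)
```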