IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo struct

BlockScaledTmem

struct BlockScaledTmem[accum_dtype: DType, MMA_M: Int, MMA_N: Int, num_accum_stages: Int, sf_dtype: DType, BM: Int, num_pipeline_stages: Int, *, cta_group: Int = 1, total_cols: Int = 512, num_sf_k_tiles: Int = 1, SFB_N: Int = MMA_N]

TMEM (Tensor Memory) region for block-scaled matmul, with typed tile accessors.

Manages the TMEM address space for block-scaled MMA operations, providing typed TmemTensor access to:

  • Accumulator tiles (one per output pipeline stage)
  • SFA scaling factor tiles for the A operand (one per k-iteration pipeline stage)
  • SFB scaling factor tiles for the B operand (one per k-iteration pipeline stage)

Memory layout (total_cols columns, 512 by default), with regions packed left to right from base_addr:

┌────────────────────┬────────────────────┬────────────────────┐
│ Accumulators       │ SFA Scales         │ SFB Scales         │
│ (stages × MMA_N)   │ (iters × cols)     │ (iters × cols)     │
└────────────────────┴────────────────────┴────────────────────┘

Here stages is num_accum_stages, iters is num_pipeline_stages, and cols is num_sf_k_tiles * (BM // 32) for SFA and num_sf_k_tiles * (SFB_N // 32) for SFB. The three regions occupy used_cols columns, which must fit within total_cols.
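A minimal Python sketch of the column bookkeeping behind this layout, assuming the accumulator region starts at column 0 relative to base_addr and the SFA and SFB regions follow it with no gaps (which is what the accum_offset, sfb_offset and used_cols formulas of this struct express). The helper name `tmem_regions` and the example parameter values are hypothetical, and the model tracks column counts only, not the TMEM address encoding.

```python
def tmem_regions(mma_n, num_accum_stages, bm, num_pipeline_stages,
                 num_sf_k_tiles=1, sfb_n=None, total_cols=512):
    """Return [start, end) column ranges (relative to base_addr) for each region."""
    if sfb_n is None:
        sfb_n = mma_n  # SFB_N defaults to MMA_N
    accum_cols = mma_n * num_accum_stages                             # stages x MMA_N
    sfa_cols = num_sf_k_tiles * (bm // 32) * num_pipeline_stages      # iters x SFA cols
    sfb_cols = num_sf_k_tiles * (sfb_n // 32) * num_pipeline_stages   # iters x SFB cols

    accum = (0, accum_cols)                  # accum_offset = 0
    sfa = (accum[1], accum[1] + sfa_cols)    # starts at sfa_offset
    sfb = (sfa[1], sfa[1] + sfb_cols)        # starts at sfb_offset
    used_cols = sfb[1]
    if used_cols > total_cols:
        raise ValueError(f"layout needs {used_cols} columns, only {total_cols} available")
    return {"accum": accum, "sfa": sfa, "sfb": sfb, "used_cols": used_cols}


# Hypothetical MXFP8-style configuration: two 128-column accumulator stages, 4 k-stages.
layout = tmem_regions(mma_n=128, num_accum_stages=2, bm=128, num_pipeline_stages=4)
print(layout)
# {'accum': (0, 256), 'sfa': (256, 272), 'sfb': (272, 288), 'used_cols': 288}
```

With mma_n=256 and two accumulator stages, the accumulators alone fill all 512 columns, so the same call raises ValueError; that is the kind of configuration the total_cols budget rules out.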

Parameters

  • accum_dtype (DType): Accumulator data type (typically float32).
  • MMA_M (Int): MMA M dimension.
  • MMA_N (Int): MMA N dimension (also the column stride between accumulator stages).
  • num_accum_stages (Int): Number of accumulator pipeline stages.
  • sf_dtype (DType): Scaling factor data type.
  • BM (Int): Block M dimension (used to size the SFA tiles).
  • num_pipeline_stages (Int): Number of k-iteration pipeline stages.
  • cta_group (Int): CTA group size (1 or 2).
  • total_cols (Int): Total TMEM columns (512 for SM100).
  • num_sf_k_tiles (Int): Scaling factor tiles per K-iteration. MXFP8 uses 1 (one SF vector per K-tile). NVFP4 uses 4 (multiple SF vectors per K-tile).
  • SFB_N (Int): N extent of the SFB tiles in the TMEM layout. Defaults to MMA_N. Set to align_up(MMA_N, SF_MN_GROUP_SIZE) when MMA_N < SF_MN_GROUP_SIZE so the TMEM tile is wide enough for the SMEM-to-TMEM copy (which always writes a full SF_MN_GROUP_SIZE group).
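A minimal sketch of the SFB_N rule above. It assumes `align_up` rounds up to the next multiple of its alignment and that SF_MN_GROUP_SIZE is 128; that value is an assumption for illustration, since the constant is defined outside this struct.

```python
SF_MN_GROUP_SIZE = 128  # assumed value for illustration; use the kernel's real constant


def align_up(value, alignment):
    """Round value up to the nearest multiple of alignment."""
    return (value + alignment - 1) // alignment * alignment


def choose_sfb_n(mma_n, group=SF_MN_GROUP_SIZE):
    """SFB_N per the parameter description: pad only when MMA_N is narrower than one group."""
    if mma_n < group:
        return align_up(mma_n, group)
    return mma_n  # default: SFB_N = MMA_N


for mma_n in (64, 128, 256):
    print(mma_n, "->", choose_sfb_n(mma_n))
# 64 -> 128
# 128 -> 128
# 256 -> 256
```

Under that assumption, padding MMA_N = 64 up to SFB_N = 128 doubles each SFB tile from 2 to 4 columns when num_sf_k_tiles is 1, trading a few TMEM columns for a copy that always fills a complete group.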

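Fields

  • base_addr (Int): Base TMEM address of the region, set by __init__. Accumulator tiles start at this address (accum_offset is 0); the SFA and SFB tiles start sfa_offset and sfb_offset columns after it.

Implemented traits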

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

accum_layout

comptime accum_layout = Layout.row_major(MMA_M, MMA_N)

accum_offset

comptime accum_offset = 0

AccumArray

comptime AccumArray = TmemArrayType[accum_dtype, Self.accum_layout, num_accum_stages, cta_group=cta_group]

AccumTile

comptime AccumTile = Self.AccumArray.Tile

sfa_layout

comptime sfa_layout = Layout.row_major(1, num_sf_k_tiles * (BM // 32))

sfa_offset

comptime sfa_offset = Self.AccumArray.num_cols

SFAArray

comptime SFAArray = TmemArrayType[sf_dtype, Self.sfa_layout, num_pipeline_stages, cta_group=cta_group]

SFATile

comptime SFATile = Self.SFAArray.Tile

sfb_layout

comptime sfb_layout = Layout.row_major(1, num_sf_k_tiles * (SFB_N // 32))
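sfa_layout and sfb_layout follow the same rule: one row of num_sf_k_tiles * (dim // 32) scale-factor columns, with dim = BM for SFA and SFB_N for SFB. The Python sketch below evaluates that width for the two formats named in the num_sf_k_tiles description; BM = 128 and SFB_N = 256 are hypothetical example values.

```python
# num_sf_k_tiles per format, as given in the parameter list.
NUM_SF_K_TILES = {"mxfp8": 1, "nvfp4": 4}


def sf_tile_cols(dim, num_sf_k_tiles):
    """Width of one SF tile: the second extent of row_major(1, num_sf_k_tiles * (dim // 32))."""
    return num_sf_k_tiles * (dim // 32)


bm, sfb_n = 128, 256  # hypothetical example sizes
for fmt, k_tiles in NUM_SF_K_TILES.items():
    print(fmt, "SFA cols:", sf_tile_cols(bm, k_tiles), "SFB cols:", sf_tile_cols(sfb_n, k_tiles))
# mxfp8 SFA cols: 4 SFB cols: 8
# nvfp4 SFA cols: 16 SFB cols: 32
```

For the same BM and SFB_N, NVFP4 tiles are four times as wide as MXFP8 tiles, so the SF regions grow fourfold while the accumulator region is unchanged.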

sfb_offset

comptime sfb_offset = MMA_N * num_accum_stages + num_sf_k_tiles * (BM // 32) * num_pipeline_stages

The SFB region starts immediately after the accumulator and SFA regions.

SFBArray

comptime SFBArray = TmemArrayType[sf_dtype, Self.sfb_layout, num_pipeline_stages, cta_group=cta_group]

SFBTile

comptime SFBTile = Self.SFBArray.Tile

used_cols

comptime used_cols = MMA_N * num_accum_stages + num_sf_k_tiles * (BM // 32) * num_pipeline_stages + num_sf_k_tiles * (SFB_N // 32) * num_pipeline_stages

Total TMEM columns occupied by the accumulator, SFA, and SFB regions.

Methods

__init__

def __init__(base_addr: Int) -> Self

Create TMEM region view at the given base address.

def __init__(addr: TmemAddress) -> Self

Create TMEM region view from a TmemAddress.

def __init__[cta: Int, max_cols: Int](alloc: TmemAllocation[cta, max_cols]) -> Self

Create TMEM region view from a TmemAllocation.

accum_tiles

def accum_tiles(self) -> Self.AccumArray

Get array of accumulator tiles.

Returns:

Self.AccumArray

sfa_tiles

def sfa_tiles(self) -> Self.SFAArray

Get array of SFA scaling factor tiles.

Returns:

Self.SFAArray

sfb_tiles

def sfb_tiles(self) -> Self.SFBArray

Get array of SFB scaling factor tiles.

Returns:

Self.SFBArray

accum

def accum[T: Intable](self, stage: T) -> Self.AccumTile

Get accumulator tile for the given pipeline stage.

Returns:

Self.AccumTile

sfa

def sfa[T: Intable](self, index: T) -> Self.SFATile

Get SFA scaling factor tile for the given k-iteration index.

Returns:

Self.SFATile

sfb

def sfb[T: Intable](self, index: T) -> Self.SFBTile

Get SFB scaling factor tile for the given k-iteration index.

Returns:

Self.SFBTile
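The accum, sfa and sfb accessors take a stage or index and return the tile at that position in its region. As a rough Python model, assuming tiles are packed back to back (MMA_N columns apart for accumulators, as the MMA_N parameter states, and one SF tile width apart for SFA and SFB) and assuming callers wrap their counters with a modulo, the column offsets relative to base_addr work out as below. The class `TmemColumnModel` and the round-robin wrapping are hypothetical illustrations, not part of this API; the real accessors return typed TMEM tiles, not integers.

```python
class TmemColumnModel:
    """Hypothetical model of accum(), sfa() and sfb() as column offsets from base_addr."""

    def __init__(self, mma_n, num_accum_stages, bm, num_pipeline_stages,
                 num_sf_k_tiles=1, sfb_n=None):
        sfb_n = mma_n if sfb_n is None else sfb_n  # SFB_N defaults to MMA_N
        self.mma_n = mma_n
        self.num_accum_stages = num_accum_stages
        self.num_pipeline_stages = num_pipeline_stages
        self.sfa_tile_cols = num_sf_k_tiles * (bm // 32)
        self.sfb_tile_cols = num_sf_k_tiles * (sfb_n // 32)
        # Region starts, following accum_offset, sfa_offset and sfb_offset.
        self.accum_offset = 0
        self.sfa_offset = mma_n * num_accum_stages
        self.sfb_offset = self.sfa_offset + self.sfa_tile_cols * num_pipeline_stages

    @staticmethod
    def _check(i, n):
        if not 0 <= i < n:
            raise IndexError(f"index {i} out of range for {n} tiles")

    def accum(self, stage):
        self._check(stage, self.num_accum_stages)
        return self.accum_offset + stage * self.mma_n  # MMA_N is the stage stride

    def sfa(self, index):
        self._check(index, self.num_pipeline_stages)
        return self.sfa_offset + index * self.sfa_tile_cols

    def sfb(self, index):
        self._check(index, self.num_pipeline_stages)
        return self.sfb_offset + index * self.sfb_tile_cols


tmem = TmemColumnModel(mma_n=128, num_accum_stages=2, bm=128, num_pipeline_stages=4)

# Assumed round-robin scheduling: output tile t uses accumulator stage
# t % num_accum_stages, and k-iteration k uses SF index k % num_pipeline_stages.
for t in range(3):
    print("tile", t, "accum col", tmem.accum(t % tmem.num_accum_stages))
for k in range(5):
    idx = k % tmem.num_pipeline_stages
    print("k", k, "sfa col", tmem.sfa(idx), "sfb col", tmem.sfb(idx))
```

Under these assumptions, accumulator stages alternate between columns 0 and 128, SFA tiles cycle through columns 256 to 268, and SFB tiles through 272 to 284, so the pipeline reuses a fixed set of tiles rather than growing with K.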