
Mojo struct

BlockScaledTmem

@register_passable(trivial) struct BlockScaledTmem[accum_dtype: DType, MMA_M: Int, MMA_N: Int, num_accum_stages: Int, sf_dtype: DType, BM: Int, num_pipeline_stages: Int, *, cta_group: Int = 1, total_cols: Int = 512]

TMEM region for block-scaled matmul with typed tile accessors.

Manages the TMEM address space for block-scaled MMA operations, providing typed TmemTensor access to:

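  • Accumulator tiles (one per accumulator pipeline stage, num_accum_stages in total)
  • SFA scaling factor tiles (one per k-iteration pipeline stage, num_pipeline_stages in total)
  • SFB scaling factor tiles (one per k-iteration pipeline stage, num_pipeline_stages in total)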

Memory layout (512 columns total):

┌──────────────────┬────────────────┬────────────────┐
│ Accumulators     │ SFA Scales     │ SFB Scales     │
│ (stages × MMA_N) │ (iters × cols) │ (iters × cols) │
└──────────────────┴────────────────┴────────────────┘
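The layout can be modelled with a few lines of integer arithmetic. The sketch below is a plain-Python model, not the Mojo implementation. It assumes that each accumulator stage occupies MMA_N columns and that each scaling-factor tile occupies one column per layout column (BM // 32 for SFA, MMA_N // 32 for SFB). The function name `tmem_column_layout` and the example sizes are hypothetical.

```python
def tmem_column_layout(mma_n, num_accum_stages, bm, num_pipeline_stages):
    """Model the column offsets of the three TMEM sub-regions.

    Assumption: a tile occupies one TMEM column per layout column.
    """
    accum_cols = num_accum_stages * mma_n            # stages x MMA_N
    sfa_cols = num_pipeline_stages * (bm // 32)      # iters x cols (SFA)
    sfb_cols = num_pipeline_stages * (mma_n // 32)   # iters x cols (SFB)

    accum_offset = 0                                 # accumulators come first
    sfa_offset = accum_offset + accum_cols           # SFA follows the accumulators
    sfb_offset = sfa_offset + sfa_cols               # SFB follows SFA
    used_cols = sfb_offset + sfb_cols                # first column past the region

    return {
        "accum_offset": accum_offset,
        "sfa_offset": sfa_offset,
        "sfb_offset": sfb_offset,
        "used_cols": used_cols,
    }


# Hypothetical configuration: 128x128 MMA, 2 accumulator stages, 4 k-stages.
layout = tmem_column_layout(mma_n=128, num_accum_stages=2, bm=128,
                            num_pipeline_stages=4)
print(layout)
```

With these example sizes the accumulators take 256 columns and each scaling-factor region takes 16, so the region ends at column 288 of the 512 available.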

Parameters

  • accum_dtype (DType): Accumulator data type (typically float32).
  • MMA_M (Int): MMA M dimension.
  • MMA_N (Int): MMA N dimension (also stage stride for accumulators).
  • num_accum_stages (Int): Number of accumulator pipeline stages.
  • sf_dtype (DType): Scaling factor data type.
  • BM (Int): Block M dimension (for SFA sizing).
  • num_pipeline_stages (Int): Number of k-iteration pipeline stages.
  • cta_group (Int): CTA group size (1 or 2).
  • total_cols (Int): Total TMEM columns (512 for SM100).

Fields

  • base_addr (Int): Base TMEM address of the region.

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

accum_layout

comptime accum_layout = Layout.row_major(MMA_M, MMA_N)

accum_offset

comptime accum_offset = 0

AccumArray

comptime AccumArray = TmemArrayType[accum_dtype, BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols].accum_layout, num_accum_stages, cta_group=cta_group]

AccumTile

comptime AccumTile = BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols].AccumArray.Tile

sfa_layout

comptime sfa_layout = Layout.row_major(1, (BM // 32))

sfa_offset

comptime sfa_offset = BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols].AccumArray.num_cols

SFAArray

comptime SFAArray = TmemArrayType[sf_dtype, BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols].sfa_layout, num_pipeline_stages, cta_group=cta_group]

SFATile

comptime SFATile = BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols].SFAArray.Tile

sfb_layout

comptime sfb_layout = Layout.row_major(1, (MMA_N // 32))

sfb_offset

comptime sfb_offset = (BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols].sfa_offset + BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols].SFAArray.num_cols)

SFBArray

comptime SFBArray = TmemArrayType[sf_dtype, BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols].sfb_layout, num_pipeline_stages, cta_group=cta_group]

SFBTile

comptime SFBTile = BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols].SFBArray.Tile

used_cols

comptime used_cols = (BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols].sfb_offset + BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols].SFBArray.num_cols)
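Since used_cols is the sum of the three sub-region widths, it can be compared against total_cols to see whether a configuration fits. The following plain-Python sketch illustrates that check under the assumption that a scaling-factor tile takes one column per layout column. The helper names `used_cols` and `max_accum_stages` are hypothetical, and this page does not say whether the Mojo struct enforces such a bound itself.

```python
def used_cols(mma_n, num_accum_stages, bm, num_pipeline_stages):
    """Columns consumed by accumulators plus SFA and SFB tiles (modelled)."""
    return (num_accum_stages * mma_n
            + num_pipeline_stages * (bm // 32)
            + num_pipeline_stages * (mma_n // 32))


def max_accum_stages(mma_n, bm, num_pipeline_stages, total_cols=512):
    """Largest accumulator stage count whose modelled layout fits in total_cols."""
    sf_cols = num_pipeline_stages * (bm // 32 + mma_n // 32)
    return max((total_cols - sf_cols) // mma_n, 0)


# Hypothetical sizes: BM = 128 with 4 k-iteration pipeline stages.
print(max_accum_stages(mma_n=128, bm=128, num_pipeline_stages=4))
print(max_accum_stages(mma_n=256, bm=128, num_pipeline_stages=4))
```

Under this model, widening MMA_N reduces the number of accumulator stages that fit, because both the accumulator stride and the SFB tiles grow with it.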

Methods

__init__

__init__(base_addr: Int) -> Self

Create TMEM region view at the given base address.

__init__(addr: TmemAddress) -> Self

Create TMEM region view from a TmemAddress.

__init__[cta: Int, max_cols: Int](alloc: TmemAllocation[cta, max_cols]) -> Self

Create TMEM region view from a TmemAllocation.

accum_tiles

accum_tiles(self) -> BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols].AccumArray

Get array of accumulator tiles.

Returns:

AccumArray: The array of accumulator tiles.

sfa_tiles

sfa_tiles(self) -> BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols].SFAArray

Get array of SFA scaling factor tiles.

Returns:

SFAArray: The array of SFA scaling factor tiles.

sfb_tiles

sfb_tiles(self) -> BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols].SFBArray

Get array of SFB scaling factor tiles.

Returns:

SFBArray: The array of SFB scaling factor tiles.

accum

accum[T: Intable](self, stage: T) -> BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols].AccumTile

Get accumulator tile for the given pipeline stage.

Returns:

AccumTile: The accumulator tile for the given pipeline stage.

sfa

sfa[T: Intable](self, index: T) -> BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols].SFATile

Get SFA scaling factor tile for the given k-iteration index.

Returns:

SFATile: The SFA scaling factor tile for the given k-iteration index.

sfb

sfb[T: Intable](self, index: T) -> BlockScaledTmem[accum_dtype, MMA_M, MMA_N, num_accum_stages, sf_dtype, BM, num_pipeline_stages, cta_group=cta_group, total_cols=total_cols].SFBTile

Get SFB scaling factor tile for the given k-iteration index.

Returns:

SFBTile: The SFB scaling factor tile for the given k-iteration index.
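To illustrate how the accum, sfa, and sfb accessors could map an index to a location, here is a plain-Python model that returns a column address instead of a typed tile. It is a sketch under stated assumptions, not the Mojo implementation: the class name `TmemRegionModel` is hypothetical, the address is modelled as base_addr plus a column offset, accumulator stages are strided by MMA_N, and scaling-factor tiles are strided by their layout width.

```python
class TmemRegionModel:
    """Plain-Python model of the tile accessors (hypothetical, not the Mojo struct)."""

    def __init__(self, base_addr, mma_n, num_accum_stages, bm, num_pipeline_stages):
        self.base_addr = base_addr
        self.mma_n = mma_n
        self.num_accum_stages = num_accum_stages
        self.num_pipeline_stages = num_pipeline_stages
        # Assumption: one TMEM column per layout column for scaling-factor tiles.
        self.sfa_tile_cols = bm // 32
        self.sfb_tile_cols = mma_n // 32
        # Sub-regions are packed back to back, accumulators first.
        self.accum_offset = 0
        self.sfa_offset = num_accum_stages * mma_n
        self.sfb_offset = self.sfa_offset + num_pipeline_stages * self.sfa_tile_cols

    def accum(self, stage):
        """Column address of the accumulator tile for a pipeline stage."""
        if not 0 <= stage < self.num_accum_stages:
            raise IndexError("accumulator stage out of range")
        return self.base_addr + self.accum_offset + stage * self.mma_n

    def sfa(self, index):
        """Column address of the SFA tile for a k-iteration index."""
        if not 0 <= index < self.num_pipeline_stages:
            raise IndexError("SFA index out of range")
        return self.base_addr + self.sfa_offset + index * self.sfa_tile_cols

    def sfb(self, index):
        """Column address of the SFB tile for a k-iteration index."""
        if not 0 <= index < self.num_pipeline_stages:
            raise IndexError("SFB index out of range")
        return self.base_addr + self.sfb_offset + index * self.sfb_tile_cols


# Hypothetical configuration matching a 128x128 MMA with 2 accumulator stages.
region = TmemRegionModel(base_addr=0, mma_n=128, num_accum_stages=2,
                         bm=128, num_pipeline_stages=4)
print(region.accum(1), region.sfa(3), region.sfb(2))
```

The model rejects out-of-range indices, so a caller that loops over more k-iterations than there are pipeline stages would need to wrap the index first, for example with `k % num_pipeline_stages`. Whether the Mojo accessors perform any such check is not stated on this page.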
