
Mojo struct

ScalesTileLoader

@register_passable(trivial) struct ScalesTileLoader[tma_origin: ImmutOrigin, dtype: DType, gmem_layout: Layout, desc_layout: Layout, /, *, cta_group: Int]

TMA-based scales tile loader for blockwise FP8.

Unlike TileLoaderTMA, this loader:

  • Uses async_copy (no multicast), since scales are not distributed across CTAs.
  • Uses (row_coord, k_coord) coordinate order, matching the scales tensor layout.
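To make the (row_coord, k_coord) ordering concrete, here is a plain-Python model (not the Mojo API) of how a blockwise FP8 scales tensor is indexed. It assumes, hypothetically, one scale per 128 elements along K for A-scales, so the scale covering element A[m, k] sits at scales[m][k // 128]. The block size and the helper names `scale_coord` and `dequantize_row` are illustrative assumptions.

```python
# Plain-Python model of blockwise FP8 scale lookup (illustrative only).
# Assumption: A-scales form a row-major 2-D tensor indexed as
# scales[row_coord][k_coord], with one scale per BLOCK_K elements along K.

BLOCK_K = 128  # hypothetical block size along K


def scale_coord(row: int, k: int, block_k: int = BLOCK_K) -> tuple[int, int]:
    """Return the (row_coord, k_coord) of the scale covering A[row, k]."""
    return (row, k // block_k)


def dequantize_row(q_row: list[float], scales: list[list[float]], row: int,
                   block_k: int = BLOCK_K) -> list[float]:
    """Multiply each quantized value in one row by its block's scale."""
    out = []
    for k, q in enumerate(q_row):
        r, kb = scale_coord(row, k, block_k)
        out.append(q * scales[r][kb])
    return out


if __name__ == "__main__":
    # 2 rows x 256 columns of data -> 2 rows x 2 K-blocks of scales.
    scales = [[0.5, 2.0], [1.0, 4.0]]
    print(scale_coord(1, 200))  # (1, 1)
    row1 = dequantize_row([1.0] * 256, scales, row=1)
    print(row1[0], row1[255])  # 1.0 4.0
```

Because the row index comes first in both the scale lookup and the tensor's layout, a loader that takes (row_coord, k_coord) can address scales tiles without swapping coordinates.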

Parameters

  • tma_origin (ImmutOrigin): Origin of the TMA descriptor pointer.
  • dtype (DType): Element data type (typically float8 for scales).
  • gmem_layout (Layout): Global memory tensor layout.
  • desc_layout (Layout): TMA descriptor layout (tile dimensions).
  • cta_group (Int): CTA group size: 1, or 2 for SM100 2-SM MMA.

Fields

  • tma_op (ScalesTileLoader[tma_origin, dtype, gmem_layout, desc_layout, cta_group=cta_group].TmaOpPtr): Pointer to the TMA descriptor (grid constant) used for the copies issued by this loader.

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

TmaOp

comptime TmaOp = TMATensorTile[dtype, gmem_layout, desc_layout]

TmaOpPtr

comptime TmaOpPtr = Pointer[ScalesTileLoader[tma_origin, dtype, gmem_layout, desc_layout, cta_group=cta_group].TmaOp, tma_origin]

Methods

__init__

__init__(tma_op: Pointer[ScalesTileLoader[tma_origin, dtype, gmem_layout, desc_layout, cta_group=cta_group].TmaOp, tma_origin]) -> Self

Initialize the scales tile loader.

Args:

  • tma_op (Pointer): Pointer to the TMA descriptor (a grid constant).

load

load[tile_layout: Layout, /, alignment: Int = 128](self, dest: LayoutTensor[dtype, tile_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment], ref [3] barrier: SharedMemBarrier, row_coord: Int, k_coord: Int)

Load a scales tile using TMA hardware acceleration.

Issues an async copy from global memory to shared memory. Unlike TileLoaderTMA, this uses (row_coord, k_coord) order, matching the scales tensor layout.

Args:

  • dest (LayoutTensor): Destination SMEM tile.
  • barrier (SharedMemBarrier): Memory barrier for TMA completion signaling.
  • row_coord (Int): Row coordinate (M for A-scales) in global memory.
  • k_coord (Int): K dimension coordinate in global memory.
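The plain-Python sketch below models what a single load call accomplishes, under stated assumptions: the global scales tensor is a row-major 2-D list, row_coord and k_coord are non-negative element offsets of the tile's top-left corner, out-of-bounds elements are zero-filled (the default TMA out-of-bounds behavior), and the barrier is modeled by a hypothetical `ModelBarrier` byte counter standing in for SharedMemBarrier. It is not the Mojo implementation and does not model asynchrony; `model_load` and `ModelBarrier` are illustrative names only.

```python
# Plain-Python model of ScalesTileLoader.load semantics (illustrative only).
# Assumptions (not taken from the Mojo API): gmem is a row-major 2-D list,
# (row_coord, k_coord) are element offsets of the tile's top-left corner,
# out-of-bounds reads are zero-filled, and the barrier is a byte counter.


class ModelBarrier:
    """Hypothetical stand-in for SharedMemBarrier: done once expected bytes arrive."""

    def __init__(self) -> None:
        self.expected = 0
        self.arrived = 0

    def expect_bytes(self, n: int) -> None:
        self.expected += n

    def complete_tx(self, n: int) -> None:
        self.arrived += n

    def done(self) -> bool:
        return self.expected > 0 and self.arrived >= self.expected


def model_load(dest, barrier, gmem, row_coord, k_coord, elem_bytes):
    """Copy a tile shaped like dest from gmem at (row_coord, k_coord) into dest."""
    rows, cols = len(dest), len(dest[0])
    for r in range(rows):
        for c in range(cols):
            gr, gc = row_coord + r, k_coord + c
            in_bounds = 0 <= gr < len(gmem) and 0 <= gc < len(gmem[0])
            dest[r][c] = gmem[gr][gc] if in_bounds else 0.0
    # The whole tile's bytes are credited, including zero-filled elements.
    barrier.complete_tx(rows * cols * elem_bytes)


if __name__ == "__main__":
    gmem = [[float(10 * r + c) for c in range(4)] for r in range(3)]  # 3 x 4 scales
    dest = [[None] * 2 for _ in range(2)]  # 2 x 2 shared-memory tile
    bar = ModelBarrier()
    bar.expect_bytes(2 * 2 * 4)  # assuming 4-byte elements
    model_load(dest, bar, gmem, row_coord=2, k_coord=3, elem_bytes=4)
    print(dest)        # [[23.0, 0.0], [0.0, 0.0]]
    print(bar.done())  # True
```

In this model the consumer waits on the barrier rather than on the copy itself, which mirrors why load takes a barrier argument: completion is signaled through the barrier's transaction count.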
