Mojo struct

TileLoaderTMA

struct TileLoaderTMA[tma_origin: ImmutOrigin, dtype: DType, tma_rank: Int, tile_shape: IndexList[tma_rank], desc_shape: IndexList[tma_rank], /, *, cta_group: Int]

TMA-based tile loader for SM100.

Wraps a pointer to a TMA descriptor together with a multicast mask. The load methods issue async_multicast_load on that descriptor, using the struct's cta_group parameter for CTA group handling.

Parameters

  • tma_origin (ImmutOrigin): Origin of the TMA descriptor pointer.
  • dtype (DType): Element data type.
  • tma_rank (Int): Rank of the TMA tile and descriptor shapes.
  • tile_shape (IndexList[tma_rank]): TMA tile shape as an IndexList.
  • desc_shape (IndexList[tma_rank]): TMA descriptor shape as an IndexList.
  • cta_group (Int): CTA group size: 1, or 2 for the SM100 2-SM MMA.

Fields

  • tma_op (TileLoaderTMA[tma_origin, dtype, tma_rank, tile_shape, desc_shape, cta_group=cta_group].TmaOpPtr): Pointer to the wrapped TMA descriptor.
  • multicast_mask (UInt16): Multicast mask used by the load methods.
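The multicast_mask field is a 16-bit value. As a rough illustration of how such a mask can be built on the host, the sketch below assumes the PTX multicast convention, in which bit i selects the CTA with rank i in the cluster. Whether TileLoaderTMA interprets the mask exactly this way is an assumption, and the helper name `build_multicast_mask` is hypothetical. The sketch is plain Python for the bit arithmetic only and does not call any Mojo API.

```python
def build_multicast_mask(cta_ranks):
    """Return a 16-bit mask with one bit set per destination CTA rank.

    Assumption: bit i corresponds to the CTA with rank i in the cluster,
    so valid ranks are 0..15 to fit a UInt16.
    """
    mask = 0
    for rank in cta_ranks:
        if not 0 <= rank < 16:
            raise ValueError(f"CTA rank {rank} does not fit in a 16-bit mask")
        mask |= 1 << rank
    return mask


# Hypothetical cluster where the CTAs with ranks 0 and 2 receive the tile.
print(bin(build_multicast_mask([0, 2])))  # 0b101
```

A mask with a single bit set targets a single CTA.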

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

TmaOp

comptime TmaOp = TMATensorTile[dtype, tma_rank, tile_shape, desc_shape]

TmaOpPtr

comptime TmaOpPtr = Pointer[TMATensorTile[dtype, tma_rank, tile_shape, desc_shape], tma_origin]

Methods

__init__

__init__(tma_op: Pointer[TMATensorTile[dtype, tma_rank, tile_shape, desc_shape], tma_origin], multicast_mask: UInt16) -> Self

Initialize the TMA tile loader.

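Args:

  • tma_op (Pointer[TMATensorTile[dtype, tma_rank, tile_shape, desc_shape], tma_origin]): Pointer to the TMA descriptor.
  • multicast_mask (UInt16): Multicast mask for the load.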

load

load[dim0: Int, dim1: Int, /, alignment: Int = 128](self, dest: TileTensor[dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED], ref[AddressSpace._value] barrier: SharedMemBarrier, k_coord: Int, row_coord: Int)

Load a tile using TMA hardware acceleration.

Issues an async multicast load from global memory to shared memory. Coordinates are in element units (not tile units).
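Because the coordinates are in element units, a caller that iterates over tiles has to scale each tile index by the tile extent before passing it to load. A minimal sketch of that conversion follows. The tile extents (64 along K, 128 rows) and the helper name `tile_to_element_coords` are hypothetical, chosen only for illustration, and the code is plain Python for the arithmetic rather than a call into the Mojo API.

```python
def tile_to_element_coords(k_tile, row_tile, tile_k, tile_rows):
    """Convert tile indices into element-unit coordinates.

    k_tile, row_tile: zero-based tile indices along K and along rows.
    tile_k, tile_rows: tile extents in elements (hypothetical values below).
    """
    return k_tile * tile_k, row_tile * tile_rows


# Hypothetical tile extents: 64 elements along K, 128 rows per tile.
TILE_K, TILE_ROWS = 64, 128

# Fourth tile along K (index 3), third tile along rows (index 2).
k_coord, row_coord = tile_to_element_coords(3, 2, TILE_K, TILE_ROWS)
print(k_coord, row_coord)  # 192 256
```

Passing the raw tile indices (3, 2) instead of (192, 256) would address a different region of global memory.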

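Args:

  • dest (TileTensor[dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]): Destination tile in shared memory.
  • barrier (SharedMemBarrier): Shared memory barrier passed to the async load.
  • k_coord (Int): K coordinate, in elements.
  • row_coord (Int): Row coordinate, in elements.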

load[LayoutType: TensorLayout](self, dest: TileTensor[dtype, LayoutType, MutAnyOrigin, address_space=AddressSpace.SHARED], ref[AddressSpace._value] barrier: SharedMemBarrier, k_coord: Int, row_coord: Int)

Load a TileTensor tile with variadic shape/stride types using TMA.

This overload accepts TileTensor tiles with swizzled layouts (created via internal_k_major) and passes them to the TMA operation.

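Args:

  • dest (TileTensor[dtype, LayoutType, MutAnyOrigin, address_space=AddressSpace.SHARED]): Destination tile in shared memory.
  • barrier (SharedMemBarrier): Shared memory barrier passed to the async load.
  • k_coord (Int): K coordinate.
  • row_coord (Int): Row coordinate.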