Mojo struct

TileLoaderTMA

struct TileLoaderTMA[tma_origin: ImmutOrigin, dtype: DType, tma_rank: Int, tile_shape: IndexList[tma_rank], desc_shape: IndexList[tma_rank], /, *, BK: Int, cluster_size: Int32, use_partitioned_multicast: Bool]

TMA-based tile loader for hardware-accelerated memory transfers.

This loader uses NVIDIA's Tensor Memory Accelerator (TMA) for efficient 2D tile transfers from global to shared memory, with optional multicast support for multi-block clusters.

Parameters

  • tma_origin (ImmutOrigin): Origin type for the TMA operation.
  • dtype (DType): Data type of the elements being loaded.
  • tma_rank (Int): Rank of the TMA tile (number of dimensions).
  • tile_shape (IndexList[tma_rank]): Shape of the complete tile in shared memory.
  • desc_shape (IndexList[tma_rank]): Shape described by the TMA descriptor (may be smaller than tile_shape).
  • BK (Int): Block size in the K dimension (used for coordinate conversion).
  • cluster_size (Int32): Number of blocks in the cluster (1 for no clustering).
  • use_partitioned_multicast (Bool): Whether to use partitioned multicast loading.

Fields

  • tma_op (TileLoaderTMA[tma_origin, dtype, tma_rank, tile_shape, desc_shape, BK=BK, cluster_size=cluster_size, use_partitioned_multicast=use_partitioned_multicast].TMATensorTilePtr)
  • rank (Int)
  • multicast_mask (UInt16)
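The multicast_mask field is a UInt16, which suggests a bitmask with at most one bit per block in the cluster. The sketch below is only an illustration of that idea, written in Python so it can be run without a GPU toolchain. The helper name full_cluster_mask is hypothetical, and the assumption that bit i selects block i of the cluster is not confirmed by this page.

```python
def full_cluster_mask(cluster_size: int) -> int:
    """Return a 16-bit mask with one bit set per block in the cluster.

    Assumption (not from the TileLoaderTMA docs): bit i selects block i.
    """
    if not 1 <= cluster_size <= 16:
        raise ValueError("cluster_size must be in 1..16 to fit in a UInt16")
    # Set the lowest `cluster_size` bits.
    return (1 << cluster_size) - 1


if __name__ == "__main__":
    # A 4-block cluster sets the four low bits.
    print(bin(full_cluster_mask(4)))
```

With cluster_size equal to 1 the mask covers only a single block, which matches the "1 for no clustering" case in the parameter list.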

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TileLoader, TrivialRegisterPassable

comptime members

TMATensorTilePtr

comptime TMATensorTilePtr = Pointer[TMATensorTile[dtype, tma_rank, tile_shape, desc_shape], tma_origin]

Methods

__init__

__init__(tma_op: Pointer[TMATensorTile[dtype, tma_rank, tile_shape, desc_shape], tma_origin], rank: Int, multicast_mask: UInt16) -> Self

Initialize the TMA tile loader.

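Args:

  • tma_op (Pointer[TMATensorTile[dtype, tma_rank, tile_shape, desc_shape], tma_origin])
  • rank (Int)
  • multicast_mask (UInt16)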

load_tile

load_tile(self, dst: LayoutTensor[dtype, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=dst.element_layout, layout_int_type=dst.layout_int_type, linear_idx_type=dst.linear_idx_type, masked=dst.masked, alignment=128], mem_barrier: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], _coords: Tuple[Int, Int])

Load a tile using TMA hardware acceleration.

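Converts the tile indices in _coords to element coordinates and initiates a TMA transfer into dst. When the loader is configured for a cluster (cluster_size greater than 1), the load uses multicast to share the data across the blocks in the cluster.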

Note: Coordinates are converted from (row, col) tile indices to (k_elements, row/col_elements) for TMA's K-major ordering.
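The conversion described in the note can be sketched as plain arithmetic. This is a minimal illustration in Python, not the loader's actual code. It assumes that the column tile index runs along K and is scaled by BK, and that the row tile index is scaled by the tile's extent in the other dimension. The function name and the tile_extent argument are hypothetical.

```python
def tile_to_element_coords(tile_coords, bk, tile_extent):
    """Convert (row, col) tile indices to (k_elements, row/col_elements).

    Assumptions (not from the TileLoaderTMA docs):
      - the col tile index advances along K in steps of `bk` elements
      - the row tile index advances in steps of `tile_extent` elements
    """
    tile_row, tile_col = tile_coords
    k_elements = tile_col * bk               # K offset comes first (K-major)
    row_elements = tile_row * tile_extent    # offset in the non-K dimension
    return (k_elements, row_elements)


if __name__ == "__main__":
    # Tile (row=2, col=3) with BK=64 and a 128-element tile extent.
    print(tile_to_element_coords((2, 3), bk=64, tile_extent=128))  # (192, 256)
```

Note that the order of the pair is swapped relative to the input: the K offset is listed first, which is what the K-major ordering in the note implies.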

Args: