
Mojo struct

TileLoaderTMA

struct TileLoaderTMA[tma_origin: ImmutOrigin, dtype: DType, tma_rank: Int, tile_shape: IndexList[tma_rank], desc_shape: IndexList[tma_rank], /, *, BK: Int, cluster_size: Int32, use_partitioned_multicast: Bool]

TMA-based tile loader for hardware-accelerated memory transfers.

This loader uses NVIDIA's Tensor Memory Accelerator (TMA) for efficient 2D tile transfers from global to shared memory, with optional multicast support for multi-block clusters.
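The following is a minimal sketch of how a multicast mask can be built. It assumes the loader's UInt16 `multicast_mask` follows the convention of the PTX multicast TMA instruction (`cp.async.bulk.tensor` with `.multicast::cluster`), where bit i selects the block with cluster rank i. The helper `cluster_multicast_mask` is hypothetical, and the code is a host-side Python model rather than the Mojo API.

```python
MAX_CLUSTER_BLOCKS = 16  # a UInt16 mask has one bit per block, so at most 16 blocks


def cluster_multicast_mask(cluster_size, ranks=None):
    """Hypothetical helper: bit i set means the block with cluster rank i receives the tile."""
    if not 1 <= cluster_size <= MAX_CLUSTER_BLOCKS:
        raise ValueError(f"cluster_size must be in [1, {MAX_CLUSTER_BLOCKS}], got {cluster_size}")
    if ranks is None:
        ranks = range(cluster_size)  # default: multicast to every block in the cluster
    mask = 0
    for r in ranks:
        if not 0 <= r < cluster_size:
            raise ValueError(f"rank {r} is outside a cluster of {cluster_size} blocks")
        mask |= 1 << r  # set the bit for this destination block
    return mask


print(bin(cluster_multicast_mask(4)))                # 0b1111: all four blocks
print(bin(cluster_multicast_mask(4, ranks=[0, 2])))  # 0b101: blocks 0 and 2 only
```

With `cluster_size` equal to 1, the mask has a single bit set, which corresponds to the non-clustered case.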

Parameters

  • tma_origin (ImmutOrigin): Origin of the TMATensorTile that the loader points to.
  • dtype (DType): Data type of the elements being loaded.
  • tma_rank (Int): Rank of the TMA tile (number of dimensions).
  • tile_shape (IndexList[tma_rank]): Shape of the complete tile in shared memory.
  • desc_shape (IndexList[tma_rank]): Shape of the box described by the TMA descriptor; may be smaller than tile_shape.
  • BK (Int): Tile size along the K dimension, used to convert tile indices to element coordinates.
  • cluster_size (Int32): Number of blocks in the cluster (1 for no clustering).
  • use_partitioned_multicast (Bool): Whether to use partitioned multicast loading.
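Because `desc_shape` can be smaller than `tile_shape`, the tile can be pictured as a grid of descriptor-sized boxes. This sketch enumerates the element offset of each box. It assumes the loader covers the tile with repeated descriptor-sized copies and that each dimension of `desc_shape` evenly divides the matching dimension of `tile_shape`. `descriptor_box_offsets` is a hypothetical Python helper, not part of the Mojo API.

```python
from itertools import product


def descriptor_box_offsets(tile_shape, desc_shape):
    """Hypothetical helper: element offset of each desc_shape box that tiles tile_shape."""
    if len(tile_shape) != len(desc_shape):
        raise ValueError("tile_shape and desc_shape must have the same rank")
    for t, d in zip(tile_shape, desc_shape):
        if d <= 0 or t % d != 0:
            raise ValueError(f"descriptor dim {d} does not evenly divide tile dim {t}")
    # One range of start offsets per dimension, stepping by the descriptor extent.
    per_dim = [range(0, t, d) for t, d in zip(tile_shape, desc_shape)]
    return list(product(*per_dim))


# A 64x128 tile covered by 64x64 descriptor boxes needs two copies.
print(descriptor_box_offsets((64, 128), (64, 64)))  # [(0, 0), (0, 64)]
```

The number of boxes is the product of `tile_shape[i] // desc_shape[i]` over all dimensions.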

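Fields

  • tma_op (TileLoaderTMA[tma_origin, dtype, tma_rank, tile_shape, desc_shape, BK=BK, cluster_size=cluster_size, use_partitioned_multicast=use_partitioned_multicast].TMATensorTilePtr): Pointer to the TMA tensor tile used to issue transfers.
  • rank (Int): Rank of this block within its cluster.
  • multicast_mask (UInt16): Bitmask selecting the cluster blocks that receive multicast loads.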
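When `use_partitioned_multicast` is true, one common scheme is for each block in the cluster to issue only a slice of the tile and multicast that slice to every block in `multicast_mask`. This way each block ends up with the full tile while global-memory traffic is split across the cluster. This page does not confirm that scheme, so treat it as an assumption. The sketch further assumes the split runs along the first tile dimension and that the `rank` field picks the slice. `partition_slice` is hypothetical.

```python
def partition_slice(tile_rows, cluster_size, rank):
    """Hypothetical helper: [start, end) rows that block `rank` loads and multicasts."""
    if cluster_size < 1 or tile_rows % cluster_size != 0:
        raise ValueError("tile_rows must split evenly across cluster_size blocks")
    if not 0 <= rank < cluster_size:
        raise ValueError(f"rank {rank} is outside a cluster of {cluster_size} blocks")
    rows_per_block = tile_rows // cluster_size
    start = rank * rows_per_block
    return start, start + rows_per_block


# Four blocks share a 128-row tile: each loads 32 rows, and together they cover all 128.
print([partition_slice(128, 4, r) for r in range(4)])
# [(0, 32), (32, 64), (64, 96), (96, 128)]
```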

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TileLoader, TrivialRegisterPassable

comptime members

TMATensorTilePtr

comptime TMATensorTilePtr = Pointer[TMATensorTile[dtype, tma_rank, tile_shape, desc_shape], tma_origin]

Methods

__init__

def __init__(tma_op: Pointer[TMATensorTile[dtype, tma_rank, tile_shape, desc_shape], tma_origin], rank: Int, multicast_mask: UInt16) -> Self

Initialize the TMA tile loader.

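Args:

  • tma_op (Pointer[TMATensorTile[dtype, tma_rank, tile_shape, desc_shape], tma_origin]): Pointer to the TMA tensor tile used to issue transfers.
  • rank (Int): Rank of this block within its cluster.
  • multicast_mask (UInt16): Bitmask selecting the cluster blocks that receive multicast loads.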

load_tile

def load_tile(self, dst: TileTensor[Storage=dst.Storage, address_space=AddressSpace.SHARED, linear_idx_type=dst.linear_idx_type, element_size=dst.element_size], mem_barrier: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], _coords: Tuple[Int, Int])

Load a tile using TMA hardware acceleration.

Converts tile indices to element coordinates and initiates a TMA transfer. For clusters, uses multicast to share data across blocks.
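TMA reports completion through the shared-memory barrier by counting bytes. Whatever code arms `mem_barrier` therefore has to expect the number of bytes that will land in this block's shared memory. Assuming every receiving block ends up holding the whole tile, in both the plain and the partitioned multicast case, the expected count equals the full tile size. This sketch computes that count. The dtype table and `expected_tx_bytes` are hypothetical and cover only a few types.

```python
from math import prod

# Bytes per element for a few dtypes (illustrative subset, not exhaustive).
DTYPE_BYTES = {"float32": 4, "float16": 2, "bfloat16": 2, "float8_e4m3fn": 1}


def expected_tx_bytes(tile_shape, dtype):
    """Hypothetical helper: bytes one block's barrier should expect for a full tile."""
    return prod(tile_shape) * DTYPE_BYTES[dtype]


print(expected_tx_bytes((64, 128), "bfloat16"))  # 16384
```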

Note: The (row, col) tile indices in _coords are converted to element coordinates ordered (k_elements, row/col_elements), matching TMA's K-major ordering.
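This sketch shows the conversion described in the note. It relies on assumptions this page does not spell out: `_coords` is (row, col), col indexes K tiles of width BK, and row indexes M or N tiles whose height is `tile_rows` (for example `tile_shape[0]`). TMA coordinates list the innermost, contiguous dimension first, which is K for K-major operands. `tile_to_tma_coords` is a hypothetical helper.

```python
def tile_to_tma_coords(coords, tile_rows, BK):
    """Hypothetical helper: (row, col) tile indices -> (k_elements, row/col_elements)."""
    row, col = coords
    if row < 0 or col < 0:
        raise ValueError("tile indices must be non-negative")
    k_elements = col * BK          # K is the contiguous (innermost) dimension
    mn_elements = row * tile_rows  # M or N offset of the tile's first row
    return k_elements, mn_elements


# Tile (row=2, col=3) with 128-row tiles and BK=64 starts at K element 192, row element 256.
print(tile_to_tma_coords((2, 3), tile_rows=128, BK=64))  # (192, 256)
```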

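Args:

  • dst (TileTensor): Destination tile in shared memory.
  • mem_barrier (UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]): Shared-memory barrier that tracks completion of the TMA transfer.
  • _coords (Tuple[Int, Int]): (row, col) tile indices of the tile to load.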