Mojo struct
TileLoaderTMA
struct TileLoaderTMA[tma_origin: ImmutOrigin, dtype: DType, tma_rank: Int, tile_shape: IndexList[tma_rank], desc_shape: IndexList[tma_rank], /, *, BK: Int, cluster_size: Int32, use_partitioned_multicast: Bool]
TMA-based tile loader for hardware-accelerated memory transfers.
This loader uses NVIDIA's Tensor Memory Accelerator (TMA) for efficient 2D tile transfers from global to shared memory, with optional multicast support for multi-block clusters.
Parameters
- tma_origin (ImmutOrigin): Origin of the TMA descriptor that tma_op points to.
- dtype (DType): Data type of the elements being loaded.
- tma_rank (Int): Rank of the TMA tile (number of dimensions).
- tile_shape (IndexList[tma_rank]): Shape of the complete tile in shared memory.
- desc_shape (IndexList[tma_rank]): Shape described by the TMA descriptor; may be smaller than tile_shape.
- BK (Int): Block size in the K dimension, used to convert tile indices to element coordinates.
- cluster_size (Int32): Number of blocks in the cluster (1 for no clustering).
- use_partitioned_multicast (Bool): Whether to use partitioned multicast loading.
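One way to picture use_partitioned_multicast is as a split of the tile's work across the cluster. The sketch below assumes that the rows of tile_shape are divided evenly among the cluster_size blocks and that each block loads only its own slice before multicasting it to the others. The helper partition_rows is hypothetical and is not part of TileLoaderTMA; it uses Int rather than Int32 for cluster_size for simplicity.

```mojo
from testing import assert_equal


# Hypothetical helper: models partitioned multicast as an even split of the
# tile's rows across the cluster. Not TileLoaderTMA's actual implementation.
fn partition_rows(tile_rows: Int, cluster_size: Int, rank: Int) -> Tuple[Int, Int]:
    """Returns (first_row, row_count) for `rank`, assuming an even split."""
    var rows_per_rank = tile_rows // cluster_size
    return (rank * rows_per_rank, rows_per_rank)


def main():
    # A 128-row tile shared by a 2-block cluster: under this model each block
    # loads 64 rows and multicasts them, so every block receives all 128 rows.
    for rank in range(2):
        var part = partition_rows(128, 2, rank)
        assert_equal(part[0], rank * 64)
        assert_equal(part[1], 64)
        print("rank", rank, "loads rows", part[0], "to", part[0] + part[1] - 1)
```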
Fields
- tma_op (TMATensorTilePtr): Pointer to the TMA tensor descriptor used for the transfers.
- rank (Int): Rank of this block within the cluster.
- multicast_mask (UInt16): Bit mask of the blocks that receive multicast data.
Implemented traits
AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TileLoader, TrivialRegisterPassable
comptime members
TMATensorTilePtr
comptime TMATensorTilePtr = Pointer[TMATensorTile[dtype, tma_rank, tile_shape, desc_shape], tma_origin]
Methods
__init__
__init__(tma_op: Pointer[TMATensorTile[dtype, tma_rank, tile_shape, desc_shape], tma_origin], rank: Int, multicast_mask: UInt16) -> Self
Initialize the TMA tile loader.
Args:
- tma_op (Pointer[TMATensorTile[dtype, tma_rank, tile_shape, desc_shape], tma_origin]): Pointer to the TMA tensor descriptor.
- rank (Int): Rank of this block within the cluster.
- multicast_mask (UInt16): Bit mask of multicast targets; each set bit selects a destination block by its rank in the cluster.
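On Hopper GPUs, the multicast form of the TMA copy takes a 16-bit mask with one bit per block rank in the cluster, which matches the UInt16 type of multicast_mask. The following is a minimal sketch of building a mask that targets every block in a cluster. The helper full_cluster_mask is hypothetical; real kernels may select only a subset of ranks.

```mojo
from testing import assert_equal


# Hypothetical helper: sets bit r for every rank r in [0, cluster_size).
fn full_cluster_mask(cluster_size: Int) -> UInt16:
    var mask = UInt16(0)
    for r in range(cluster_size):
        mask |= UInt16(1) << UInt16(r)
    return mask


def main():
    # Two-block cluster: bits 0 and 1 are set.
    assert_equal(full_cluster_mask(2), UInt16(3))
    # A single block (no clustering) targets only itself.
    assert_equal(full_cluster_mask(1), UInt16(1))
    print(full_cluster_mask(4))
```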
load_tile
load_tile(self, dst: LayoutTensor[dtype, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=dst.element_layout, layout_int_type=dst.layout_int_type, linear_idx_type=dst.linear_idx_type, masked=dst.masked, alignment=128], mem_barrier: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], _coords: Tuple[Int, Int])
Load a tile using TMA hardware acceleration.
Converts tile indices to element coordinates and initiates a TMA transfer. For clusters, uses multicast to share data across blocks.
Note: The (row, col) tile indices are converted to element coordinates (k_elements, row/col_elements), with the K offset first to match TMA's K-major ordering.
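TMA coordinates list the innermost (contiguous) dimension first, which is why the K offset leads in a K-major layout. The following is a minimal sketch of the conversion. It assumes that col_tile_idx indexes the K dimension in steps of BK and that row_tile_idx indexes the other dimension in steps of tile_shape[0]. The helper tile_to_tma_coords and this exact mapping are assumptions, not the loader's source.

```mojo
from testing import assert_equal


# Hypothetical helper: maps (row_tile_idx, col_tile_idx) tile indices to
# element coordinates ordered (k_elements, row_elements), innermost first.
fn tile_to_tma_coords(
    row_tile_idx: Int, col_tile_idx: Int, BK: Int, tile_rows: Int
) -> Tuple[Int, Int]:
    var k_elements = col_tile_idx * BK
    var row_elements = row_tile_idx * tile_rows
    return (k_elements, row_elements)


def main():
    # Tile at row index 2, K-step 3, with 128-row tiles and BK = 64.
    var coords = tile_to_tma_coords(2, 3, 64, 128)
    assert_equal(coords[0], 192)  # 3 * 64 elements along K
    assert_equal(coords[1], 256)  # 2 * 128 elements along rows
    print(coords[0], coords[1])
```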
Args:
- dst (LayoutTensor[dtype, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=dst.element_layout, layout_int_type=dst.layout_int_type, linear_idx_type=dst.linear_idx_type, masked=dst.masked, alignment=128]): Destination tile in shared memory.
- mem_barrier (UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]): Shared-memory barrier that is signaled as the TMA transfer completes.
- _coords (Tuple[Int, Int]): Tile coordinates (row_tile_idx, col_tile_idx).
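A TMA transfer reports its progress to mem_barrier as a count of bytes written. For this reason, a block waiting on the barrier typically first registers how many bytes to expect. Under the usual Hopper multicast behavior, each receiving block's barrier counts the bytes written into that block's shared memory. This would mean every block expects the full tile size even when it issued only part of the load. The following is a minimal sketch of that byte count for a 2D tile. The helper expected_tx_bytes is hypothetical, and the element size is passed in by the caller.

```mojo
from testing import assert_equal


# Hypothetical helper: bytes one full 2D tile writes into a block's shared memory.
fn expected_tx_bytes(tile_rows: Int, tile_cols: Int, bytes_per_element: Int) -> Int:
    return tile_rows * tile_cols * bytes_per_element


def main():
    # A 128 x 64 bfloat16 tile (2 bytes per element) is 16 KiB.
    assert_equal(expected_tx_bytes(128, 64, 2), 16384)
    print(expected_tx_bytes(128, 64, 2))
```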