Mojo struct

ScatterGatherTMA

@register_passable(trivial) struct ScatterGatherTMA[tma_origin: ImmutableOrigin, dtype: DType, tile_layout: Layout, desc_layout: Layout, /, *, BK: UInt, cluster_size: Int32, use_partitioned_multicast: Bool]

TMA-based tile loader for hardware-accelerated memory transfers.

This loader uses NVIDIA's Tensor Memory Accelerator (TMA) for efficient 2D tile transfers from global to shared memory, with optional multicast support for multi-block clusters.

Parameters

  • tma_origin (ImmutableOrigin): Origin type for the TMA operation.
  • dtype (DType): Data type of the elements being loaded.
  • tile_layout (Layout): Layout of the complete tile in shared memory.
  • desc_layout (Layout): Layout described by the TMA descriptor (may be smaller).
  • BK (UInt): Block size in the K dimension (for coordinate conversion).
  • cluster_size (Int32): Number of blocks in the cluster (1 for no clustering).
  • use_partitioned_multicast (Bool): Whether to use partitioned multicast loading.

Fields

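  • tma_op (Pointer[TMATensorTile[dtype, tile_layout, desc_layout], tma_origin]): Pointer to the TMA tensor descriptor.
  • rank (UInt): Rank of this block within the cluster.
  • multicast_mask (UInt16): Bit mask for multicast targets.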

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, Movable, ScatterGather, UnknownDestructibility

Aliases

__copyinit__is_trivial

alias __copyinit__is_trivial = True

__del__is_trivial

alias __del__is_trivial = True

__moveinit__is_trivial

alias __moveinit__is_trivial = True

TMATensorTilePtr

alias TMATensorTilePtr = Pointer[TMATensorTile[dtype, tile_layout, desc_layout], tma_origin]

Methods

__init__

__init__(tma_op: Pointer[TMATensorTile[dtype, tile_layout, desc_layout], tma_origin], rank: UInt, multicast_mask: UInt16) -> Self

Initialize the TMA tile loader.

Args:

  • tma_op (Pointer): Pointer to the TMA tensor descriptor.
  • rank (UInt): Rank of this block within the cluster.
  • multicast_mask (UInt16): Bit mask for multicast targets.
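
As a rough illustration of what a 16-bit multicast mask can encode, the sketch below builds and decodes one in plain Python. It is not part of the Mojo API: the helper names `multicast_mask_for_cluster` and `multicast_targets` are hypothetical, and it assumes that bit i of the mask selects the block with rank i in the cluster, which this page does not state.

```python
MASK_BITS = 16  # multicast_mask is a UInt16, so at most 16 target blocks


def multicast_mask_for_cluster(cluster_size: int) -> int:
    """Hypothetical helper: set one bit for every block rank in the cluster.

    Assumes bit i selects the block whose rank is i.
    """
    if not 1 <= cluster_size <= MASK_BITS:
        raise ValueError("cluster_size must be between 1 and 16")
    return (1 << cluster_size) - 1


def multicast_targets(mask: int) -> list[int]:
    """Hypothetical helper: list the block ranks selected by a mask."""
    return [rank for rank in range(MASK_BITS) if (mask >> rank) & 1]


if __name__ == "__main__":
    mask = multicast_mask_for_cluster(4)
    print(f"{mask:#06x}", multicast_targets(mask))
```

Under this assumption, a mask with a single bit set addresses only one block, while a mask covering every rank broadcasts the tile to the whole cluster.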

load_tile

load_tile(self, dst: LayoutTensor[dtype, layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=128], mem_barrier: UnsafePointer[SharedMemBarrier, address_space=AddressSpace(3)], _coords: Tuple[UInt, UInt])

Load a tile using TMA hardware acceleration.

Converts tile indices to element coordinates and initiates a TMA transfer. For clusters, uses multicast to share data across blocks.

Note: The (row, col) tile indices in _coords are converted to element offsets and reordered as (k_elements, row/col_elements) to match TMA's K-major ordering.

Args:

  • dst (LayoutTensor): Destination tile in shared memory.
  • mem_barrier (UnsafePointer): Memory barrier for synchronization.
  • _coords (Tuple): Tile coordinates (row_tile_idx, col_tile_idx).
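
The coordinate conversion described in the note can be modelled with a few lines of plain Python. This is a sketch under stated assumptions, not the library's implementation: the function name `tile_to_element_coords` and the `rows_per_tile` parameter are hypothetical, and the sketch assumes that the column tile index walks the K dimension in steps of BK while the row tile index is scaled by the tile height.

```python
def tile_to_element_coords(
    row_tile_idx: int,
    col_tile_idx: int,
    rows_per_tile: int,  # hypothetical: tile height in elements
    bk: int,             # BK: block size in the K dimension
) -> tuple[int, int]:
    """Convert (row, col) tile indices to K-major element coordinates.

    Assumption: the column index advances along K in steps of `bk`,
    and the K offset is emitted first to match K-major ordering.
    """
    k_elements = col_tile_idx * bk
    row_elements = row_tile_idx * rows_per_tile
    return (k_elements, row_elements)


if __name__ == "__main__":
    # Tile (row=2, col=3) with 64-row tiles and BK=32.
    print(tile_to_element_coords(2, 3, rows_per_tile=64, bk=32))
```

The point of the sketch is the swap in ordering: callers pass (row, col) tile indices, while the transfer is addressed with the K offset first.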