
Mojo struct

TileLoaderTMA

struct TileLoaderTMA[tma_origin: ImmutOrigin, dtype: DType, tma_rank: Int, tile_shape: IndexList[tma_rank], desc_shape: IndexList[tma_rank], /, *, cta_group: Int]

TMA-based tile loader for SM100.

Wraps a TMA descriptor and multicast mask for efficient tile loading. The load method issues async_multicast_load with proper CTA group handling.

Parameters

  • tma_origin (ImmutOrigin): Origin of the TMA descriptor pointer.
  • dtype (DType): Element data type.
  • tma_rank (Int): Rank of the TMA tile/descriptor shapes.
  • tile_shape (IndexList[tma_rank]): TMA tile shape as IndexList.
  • desc_shape (IndexList[tma_rank]): TMA descriptor shape as IndexList.
  • cta_group (Int): CTA group size (1 or 2 for SM100 2-SM MMA).

Fields

  • tma_op (TileLoaderTMA[tma_origin, dtype, tma_rank, tile_shape, desc_shape, cta_group=cta_group].TmaOpPtr): Pointer to the wrapped TMA descriptor.
  • multicast_mask (UInt16): Multicast mask used for the load.
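
As a rough illustration of how a UInt16 multicast mask can be formed, the sketch below sets one bit per destination CTA rank in the cluster, which is the convention the underlying PTX multicast instruction uses for its 16-bit CTA mask. The helper name make_multicast_mask and the choice of consecutive ranks starting at 0 are assumptions made for illustration; this page does not specify how callers compute the mask.

```mojo
# Hypothetical helper, not part of TileLoaderTMA: builds a 16-bit mask with
# one bit set for each of the first `num_ctas` CTA ranks in a cluster.
def make_multicast_mask(num_ctas: Int) -> UInt16:
    var mask = UInt16(0)
    for rank in range(num_ctas):
        # Bit `rank` marks the CTA with that rank as a multicast destination.
        mask |= UInt16(1) << UInt16(rank)
    return mask


def main():
    # Two CTAs (ranks 0 and 1) set the two lowest bits: binary 11.
    print(make_multicast_mask(2))
    # Four CTAs (ranks 0 to 3) set the four lowest bits: binary 1111.
    print(make_multicast_mask(4))
```

A mask built this way could then be passed as the multicast_mask argument of the constructor, assuming the destination CTAs really are the lowest-ranked ones in the cluster.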

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

TmaOp

comptime TmaOp = TMATensorTile[dtype, tma_rank, tile_shape, desc_shape]

TmaOpPtr

comptime TmaOpPtr = Pointer[TMATensorTile[dtype, tma_rank, tile_shape, desc_shape], tma_origin]

Methods

__init__

def __init__(tma_op: Pointer[TMATensorTile[dtype, tma_rank, tile_shape, desc_shape], tma_origin], multicast_mask: UInt16) -> Self

Initialize the TMA tile loader.

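Args:

  • tma_op (TmaOpPtr): Pointer to the TMA descriptor to wrap.
  • multicast_mask (UInt16): Multicast mask used for the load.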

load

def load[dim0: Int, dim1: Int, /, alignment: Int = Int(128)](self, dest: TileTensor[dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED], ref[AddressSpace._value] barrier: SharedMemBarrier, k_coord: Int, row_coord: Int)

Load a tile using TMA hardware acceleration.

Issues an async multicast load from global memory to shared memory. Coordinates are in element units (not tile units).
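
Because the coordinates are in element units, a caller that iterates over tiles has to scale each tile index by the tile extent before passing it to load. The following is a minimal sketch of that arithmetic; the helper name tile_to_element_coord and the tile extents (64 along K, 128 rows) are assumptions chosen for illustration, not values taken from this API.

```mojo
# Hypothetical helper, not part of TileLoaderTMA: converts a tile index into
# the element offset of that tile's first element along one dimension.
def tile_to_element_coord(tile_index: Int, tile_extent: Int) -> Int:
    return tile_index * tile_extent


def main():
    # Assumed tile extents for illustration only.
    var tile_k = 64
    var tile_rows = 128

    # Tile index 3 along K starts at element 3 * 64 = 192.
    var k_coord = tile_to_element_coord(3, tile_k)
    # Tile index 1 along the row dimension starts at element 1 * 128 = 128.
    var row_coord = tile_to_element_coord(1, tile_rows)

    # These element offsets are what would be passed as k_coord and row_coord.
    print(k_coord, row_coord)
```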

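Args:

  • dest (TileTensor): Destination tile in shared memory.
  • barrier (SharedMemBarrier): Shared-memory barrier passed to the TMA load.
  • k_coord (Int): Coordinate along the K dimension, in elements.
  • row_coord (Int): Row coordinate, in elements.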

def load[LayoutType: TensorLayout](self, dest: TileTensor[dtype, LayoutType, MutAnyOrigin, address_space=AddressSpace.SHARED], ref[AddressSpace._value] barrier: SharedMemBarrier, k_coord: Int, row_coord: Int)

Load a TileTensor tile with variadic shape/stride types using TMA.

This overload accepts TileTensor tiles with swizzled layouts (created via internal_k_major) and passes them to the TMA operation.

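Args:

  • dest (TileTensor): Destination tile in shared memory.
  • barrier (SharedMemBarrier): Shared-memory barrier passed to the TMA load.
  • k_coord (Int): Coordinate along the K dimension.
  • row_coord (Int): Row coordinate.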