Mojo struct
TileLoaderTMA
struct TileLoaderTMA[tma_origin: ImmutOrigin, dtype: DType, tma_rank: Int, tile_shape: IndexList[tma_rank], desc_shape: IndexList[tma_rank], /, *, cta_group: Int]
TMA-based tile loader for SM100.
Wraps a TMA descriptor and multicast mask for efficient tile loading. The load method issues async_multicast_load with proper CTA group handling.
Parameters
- tma_origin (ImmutOrigin): Origin of the TMA descriptor pointer.
- dtype (DType): Element data type.
- tma_rank (Int): Rank of the TMA tile/descriptor shapes.
- tile_shape (IndexList[tma_rank]): TMA tile shape as IndexList.
- desc_shape (IndexList[tma_rank]): TMA descriptor shape as IndexList.
- cta_group (Int): CTA group size (1 or 2 for SM100 2-SM MMA).
Fields
- tma_op (TileLoaderTMA[tma_origin, dtype, tma_rank, tile_shape, desc_shape, cta_group=cta_group].TmaOpPtr): Pointer to the TMA descriptor.
- multicast_mask (UInt16): Multicast mask for cluster distribution.
Implemented traits
AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable
comptime members
TmaOp
comptime TmaOp = TMATensorTile[dtype, tma_rank, tile_shape, desc_shape]
TmaOpPtr
comptime TmaOpPtr = Pointer[TMATensorTile[dtype, tma_rank, tile_shape, desc_shape], tma_origin]
Methods
__init__
__init__(tma_op: Pointer[TMATensorTile[dtype, tma_rank, tile_shape, desc_shape], tma_origin], multicast_mask: UInt16) -> Self
Initialize the TMA tile loader.
Args:
- tma_op (Pointer[TMATensorTile[dtype, tma_rank, tile_shape, desc_shape], tma_origin]): Pointer to TMA descriptor (grid constant).
- multicast_mask (UInt16): Multicast mask for cluster distribution.
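A minimal sketch of how a 16-bit multicast mask might be assembled before it is passed to __init__. It assumes the common TMA multicast convention in which bit i of the mask selects the CTA with rank i in the cluster; this page does not state the convention, and the helper name build_multicast_mask is hypothetical.

```mojo
# Hypothetical helper, not part of TileLoaderTMA.
# Assumption: bit i of the mask selects the CTA with rank i in the cluster.
def build_multicast_mask(first_rank: Int, count: Int) -> UInt16:
    var mask: UInt16 = 0
    for i in range(count):
        # Set the bit for each CTA rank in [first_rank, first_rank + count).
        mask |= UInt16(1) << UInt16(first_rank + i)
    return mask


def main():
    # CTAs with ranks 0 and 1.
    print(build_multicast_mask(0, 2))
    # CTAs with ranks 2 and 3 of a 4-CTA cluster.
    print(build_multicast_mask(2, 2))
```

The UInt16 field width caps the mask at 16 CTA ranks, so first_rank + count should not exceed 16 in this sketch.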
load
load[dim0: Int, dim1: Int, /, alignment: Int = 128](self, dest: TileTensor[dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED], ref[AddressSpace._value] barrier: SharedMemBarrier, k_coord: Int, row_coord: Int)
Load a tile using TMA hardware acceleration.
Issues an async multicast load from global memory to shared memory. Coordinates are in element units (not tile units).
Args:
- dest (TileTensor[dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]): Destination SMEM TileTensor tile.
- barrier (SharedMemBarrier): Memory barrier for TMA completion signaling.
- k_coord (Int): K dimension coordinate in global memory (elements).
- row_coord (Int): Row coordinate (M for A, N for B) in global memory (elements).
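Because load takes coordinates in elements, a kernel that iterates over tiles has to scale its tile indices by the tile extents before calling it. The following is a minimal sketch of that conversion only; the helper tile_to_element_coord and the extents bk and bm are hypothetical and do not come from this page.

```mojo
# Hypothetical helper: convert a tile index into the element offset that
# load expects for k_coord or row_coord.
def tile_to_element_coord(tile_index: Int, tile_extent: Int) -> Int:
    return tile_index * tile_extent


def main():
    # Assumed tile extents in elements; real values depend on the kernel.
    var bk = 64
    var bm = 128
    # K tile with index 3 and row block with index 1.
    var k_coord = tile_to_element_coord(3, bk)
    var row_coord = tile_to_element_coord(1, bm)
    print(k_coord, row_coord)
```

Passing the raw tile indices (3 and 1 here) instead of the scaled values would address a different region of global memory than intended.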
load[LayoutType: TensorLayout](self, dest: TileTensor[dtype, LayoutType, MutAnyOrigin, address_space=AddressSpace.SHARED], ref[AddressSpace._value] barrier: SharedMemBarrier, k_coord: Int, row_coord: Int)
Load a TileTensor tile with variadic shape/stride types using TMA.
This overload accepts TileTensor tiles with swizzled layouts (created via internal_k_major) and passes them to the TMA operation.
Args:
- dest (TileTensor[dtype, LayoutType, MutAnyOrigin, address_space=AddressSpace.SHARED]): Destination SMEM TileTensor tile with swizzled layout.
- barrier (SharedMemBarrier): Memory barrier for TMA completion signaling.
- k_coord (Int): K dimension coordinate in global memory (elements).
- row_coord (Int): Row coordinate (M for A, N for B) in global memory (elements).