
Mojo struct

TileLoaderTMAIm2col

struct TileLoaderTMAIm2col[tma_origin: ImmutOrigin, dtype: DType, tma_rank: Int, tile_shape: IndexList[tma_rank], desc_shape: IndexList[tma_rank], /, *, cta_group: Int]

TMA tile loader using hardware im2col for implicit GEMM convolution.

Uses a TMATensorTileIm2col descriptor (encoded with cuTensorMapEncodeIm2col) so that the TMA hardware performs the im2col coordinate transformation. Callers pass coordinates in GEMM space:

  • k_coord: K dimension (C * R * S reduction)
  • m_coord: M dimension (batch * H_out * W_out spatial)
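To make the GEMM-space coordinates concrete, the following is a minimal Python sketch of the index arithmetic that the im2col transformation implies. It is an illustration only, not the Mojo API or the hardware's actual implementation: the `ConvShape` type and the helper names are hypothetical, and the decomposition order (batch slowest and `w_out` fastest for M; channel fastest for K, as in an NHWC-style layout) is an assumption that the descriptor encoding may not share.

```python
from typing import NamedTuple, Tuple


class ConvShape(NamedTuple):
    """Hypothetical container for the convolution extents used below."""
    batch: int
    h_out: int
    w_out: int
    channels: int  # C
    filter_h: int  # R
    filter_w: int  # S


def gemm_dims(shape: ConvShape) -> Tuple[int, int]:
    """Return (M, K) of the implicit GEMM: M = batch*H_out*W_out, K = C*R*S."""
    m = shape.batch * shape.h_out * shape.w_out
    k = shape.channels * shape.filter_h * shape.filter_w
    return m, k


def decompose_m(m_coord: int, shape: ConvShape) -> Tuple[int, int, int]:
    """Split m_coord into (n, h, w); assumes batch is slowest, w_out fastest."""
    n, rem = divmod(m_coord, shape.h_out * shape.w_out)
    h, w = divmod(rem, shape.w_out)
    return n, h, w


def decompose_k(k_coord: int, shape: ConvShape) -> Tuple[int, int, int]:
    """Split k_coord into (r, s, c); assumes the channel index is fastest."""
    rs, c = divmod(k_coord, shape.channels)
    r, s = divmod(rs, shape.filter_w)
    return r, s, c
```

With a hardware im2col descriptor, the kernel supplies only `k_coord` and `m_coord`; a decomposition of this kind happens inside the TMA unit rather than in kernel code.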

Parameters

  • tma_origin (ImmutOrigin): Origin of the TMA descriptor pointer.
  • dtype (DType): Element data type.
  • tma_rank (Int): Rank of the TMA tile/descriptor shapes.
  • tile_shape (IndexList[tma_rank]): TMA tile shape as IndexList.
  • desc_shape (IndexList[tma_rank]): TMA descriptor shape as IndexList.
  • cta_group (Int): CTA group size (1 or 2 for SM100).

Fields

  • tma_op (TileLoaderTMAIm2col[tma_origin, dtype, tma_rank, tile_shape, desc_shape, cta_group=cta_group].TmaOpPtr): Pointer to the TMATensorTileIm2col descriptor.
  • multicast_mask (UInt16): Multicast mask supplied at construction.

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

TmaOp

comptime TmaOp = TMATensorTileIm2col[dtype, tma_rank, tile_shape, desc_shape]

TmaOpPtr

comptime TmaOpPtr = Pointer[TMATensorTileIm2col[dtype, tma_rank, tile_shape, desc_shape], tma_origin]

Methods

__init__

def __init__(tma_op: Pointer[TMATensorTileIm2col[dtype, tma_rank, tile_shape, desc_shape], tma_origin], multicast_mask: UInt16) -> Self

load

def load[LayoutType: TensorLayout](self, dest: TileTensor[dtype, LayoutType, MutAnyOrigin, address_space=AddressSpace.SHARED], ref[AddressSpace._value] barrier: SharedMemBarrier, k_coord: Int, m_coord: Int)

Load a TileTensor tile using im2col TMA.

Args: