
Mojo struct

TileLoaderTMAIm2col

@register_passable(trivial) struct TileLoaderTMAIm2col[tma_origin: ImmutOrigin, dtype: DType, gmem_layout: Layout, desc_layout: Layout, /, *, cta_group: Int]

TMA tile loader using hardware im2col for implicit GEMM convolution.

Uses a TMATensorTileIm2col descriptor (encoded with cuTensorMapEncodeIm2col), so the im2col coordinate transformation is performed by the TMA hardware instead of in software. Coordinates passed to the loader are in GEMM space:

  • k_coord: K dimension (C * R * S reduction)
  • m_coord: M dimension (batch * H_out * W_out spatial)
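To make the GEMM-space coordinates concrete, the sketch below is a plain-Python reference for the im2col mapping that the hardware performs: it splits an M index into (batch, output row, output column) and a K index into (filter row, filter column, channel), then reads the matching NHWC activation element. It is not the Mojo API. The function names are hypothetical, and the K ordering (channel fastest, k = (r * S + s) * C + c), the stride/padding/dilation handling and zero padding are assumptions chosen for illustration.

```python
# Hypothetical host-side reference for the implicit-GEMM im2col mapping.
# Assumptions (not taken from the Mojo API): NHWC activations, K ordered with
# the channel index fastest (k = (r * S + s) * C + c), zero padding.

def conv_out_size(size, filt, pad, stride, dilation):
    """Output extent of a convolution along one spatial dimension."""
    return (size + 2 * pad - dilation * (filt - 1) - 1) // stride + 1


def im2col_element(x, m, k, R, S, pad=(0, 0), stride=(1, 1), dilation=(1, 1)):
    """Return A[m, k] of the implicit GEMM operand, read from NHWC tensor x."""
    N, H, W, C = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    H_out = conv_out_size(H, R, pad[0], stride[0], dilation[0])
    W_out = conv_out_size(W, S, pad[1], stride[1], dilation[1])
    # M = batch * H_out * W_out: split m into (n, p, q).
    n, rem = divmod(m, H_out * W_out)
    p, q = divmod(rem, W_out)
    # K = C * R * S: split k into (r, s, c), channel fastest (assumed order).
    rs, c = divmod(k, C)
    r, s = divmod(rs, S)
    # Input pixel touched by filter tap (r, s) for output pixel (p, q).
    h = p * stride[0] - pad[0] + r * dilation[0]
    w = q * stride[1] - pad[1] + s * dilation[1]
    if 0 <= h < H and 0 <= w < W:
        return x[n][h][w][c]
    return 0  # tap falls in the padding region
```

Summing A[m, k] * B[k, co] over k then reproduces the direct convolution for output pixel m and output channel co, which is why the loader only needs a (k_coord, m_coord) pair per tile.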

Parameters

  • tma_origin (ImmutOrigin): Origin of the TMA descriptor pointer.
  • dtype (DType): Element data type.
  • gmem_layout (Layout): Global memory layout of activation tensor (NHWC).
  • desc_layout (Layout): TMA descriptor layout.
  • cta_group (Int): CTA group size (1 or 2 for SM100).

Fields

  • tma_op (TileLoaderTMAIm2col[tma_origin, dtype, gmem_layout, desc_layout, cta_group=cta_group].TmaOpPtr): Pointer to the im2col TMA descriptor (TmaOp).
  • multicast_mask (UInt16): Multicast mask supplied at construction.

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, TrivialRegisterType

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

TmaOp

comptime TmaOp = TMATensorTileIm2col[dtype, gmem_layout, desc_layout]

TmaOpPtr

comptime TmaOpPtr = Pointer[TileLoaderTMAIm2col[tma_origin, dtype, gmem_layout, desc_layout, cta_group=cta_group].TmaOp, tma_origin]

Methods

__init__

__init__(tma_op: Pointer[TileLoaderTMAIm2col[tma_origin, dtype, gmem_layout, desc_layout, cta_group=cta_group].TmaOp, tma_origin], multicast_mask: UInt16) -> Self

load

load[tile_layout: Layout, /, alignment: Int = 128](self, dest: LayoutTensor[dtype, tile_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment], ref[AddressSpace._value._mlir_value] barrier: SharedMemBarrier, k_coord: Scalar[DType.uint], m_coord: Scalar[DType.uint])

Load an activation tile using im2col TMA.

Note: This method issues a non-multicast TMA load, because CUTLASS disables multicast im2col on SM100/SM120 (SM90_TMA_LOAD_IM2COL_MULTICAST hits CUTE_INVALID_CONTROL_PATH).

Args:

  • dest (LayoutTensor): Destination SMEM tile.
  • barrier (SharedMemBarrier): Memory barrier for TMA completion signaling.
  • k_coord (Scalar): K dimension coordinate (C * R * S indexing).
  • m_coord (Scalar): M dimension coordinate (batch * H_out * W_out indexing).
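A mainloop typically calls load once per K step for a fixed output tile. The sketch below shows one plausible schedule of (k_coord, m_coord) pairs for a single CTA tile. It is an illustration only: the helper name is hypothetical, and it assumes the coordinates are element offsets into GEMM-space K and M (not tile indices) and that BM and BK are the shared-memory tile extents.

```python
# Hypothetical schedule of (k_coord, m_coord) pairs passed to load() for one
# output tile. Assumes element offsets in GEMM space, tile extents BM x BK.

def im2col_load_coords(tile_m, M, K, BM, BK):
    """Return the (k_coord, m_coord) pairs for output tile tile_m."""
    m_coord = tile_m * BM
    if m_coord >= M:
        return []  # tile lies entirely past the end of M
    # Ceil-divide K; a partial last tile relies on TMA out-of-bounds fill.
    num_k_iters = (K + BK - 1) // BK
    return [(k_iter * BK, m_coord) for k_iter in range(num_k_iters)]
```

For example, with C = 64 and a 3x3 filter, K = 64 * 3 * 3 = 576, so BK = 64 gives 9 loads per tile; with one 56x56 output image, M = 3136, and tile_m = 24 with BM = 128 starts at m_coord = 3072.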

load[LayoutType: TensorLayout](self, dest: TileTensor[dtype, LayoutType, MutAnyOrigin, address_space=AddressSpace.SHARED], ref[AddressSpace._value._mlir_value] barrier: SharedMemBarrier, k_coord: Scalar[DType.uint], m_coord: Scalar[DType.uint])

Load a TileTensor tile using im2col TMA.

Args:

  • dest (TileTensor): Destination SMEM TileTensor tile.
  • barrier (SharedMemBarrier): Memory barrier for TMA completion signaling.
  • k_coord (Scalar): K dimension coordinate (C * R * S indexing).
  • m_coord (Scalar): M dimension coordinate (batch * H_out * W_out indexing).
