Mojo struct
TileLoaderTMAIm2col
@register_passable(trivial)
struct TileLoaderTMAIm2col[tma_origin: ImmutOrigin, dtype: DType, gmem_layout: Layout, desc_layout: Layout, /, *, cta_group: Int]
TMA tile loader using hardware im2col for implicit GEMM convolution.
Uses a TMATensorTileIm2col descriptor (encoded with cuTensorMapEncodeIm2col) so that the TMA hardware performs the im2col coordinate transformation. Coordinates are given in GEMM space:
- k_coord: K dimension (C * R * S reduction)
- m_coord: M dimension (batch * H_out * W_out spatial)
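The GEMM-space coordinates can be related back to the NHWC activation tensor with ordinary index arithmetic. The sketch below is a hypothetical host-side illustration and is not part of this API. It assumes M is ordered as m = (n * H_out + p) * W_out + q (output column fastest) and K as k = (r * S + s) * C + c (channel fastest). The helper names `conv_out_size` and `input_coord` and the stride/padding/dilation arguments are also assumptions, since this page does not describe how the descriptor encodes them or which K ordering it uses.

```mojo
# Hypothetical sketch, not part of the TileLoaderTMAIm2col API: decode an
# implicit-GEMM coordinate (m, k) into the NHWC input element it reads.
# Assumed orderings: m = (n * H_out + p) * W_out + q, k = (r * S + s) * C + c.


fn conv_out_size(size: Int, filt: Int, stride: Int, pad: Int, dil: Int) -> Int:
    # Standard convolution output extent along one spatial dimension.
    return (size + 2 * pad - dil * (filt - 1) - 1) // stride + 1


fn input_coord(out_idx: Int, filt_idx: Int, stride: Int, pad: Int, dil: Int) -> Int:
    # Input row/column touched by output position out_idx and filter tap filt_idx.
    # Negative values, or values >= the input extent, fall in the padding region.
    return out_idx * stride - pad + filt_idx * dil


def main():
    # Example problem: N=2, H=W=8, C=16, 3x3 filter, stride 1, padding 1.
    var n_size = 2
    var h = 8
    var w = 8
    var c_size = 16
    var r_size = 3
    var s_size = 3
    var stride = 1
    var pad = 1
    var dil = 1

    var h_out = conv_out_size(h, r_size, stride, pad, dil)
    var w_out = conv_out_size(w, s_size, stride, pad, dil)
    # prints: GEMM M = 128 K = 144
    print("GEMM M =", n_size * h_out * w_out, "K =", c_size * r_size * s_size)

    var m = 75
    var k = 37
    # M decomposition: batch n, output row p, output column q (fastest).
    var n = m // (h_out * w_out)
    var p = (m // w_out) % h_out
    var q = m % w_out
    # K decomposition: filter row r, filter column s, channel c (fastest).
    var r = k // (s_size * c_size)
    var s = (k // c_size) % s_size
    var c = k % c_size

    var ih = input_coord(p, r, stride, pad, dil)
    var iw = input_coord(q, s, stride, pad, dil)
    # prints: input (n, h, w, c) = 1 0 4 5
    print("input (n, h, w, c) =", n, ih, iw, c)
```

With hardware im2col this arithmetic happens inside the TMA unit, which is why the loader only needs the two GEMM-space coordinates per request.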
Parameters
- tma_origin (ImmutOrigin): Origin of the TMA descriptor pointer.
- dtype (DType): Element data type.
- gmem_layout (Layout): Global memory layout of the activation tensor (NHWC).
- desc_layout (Layout): TMA descriptor layout.
- cta_group (Int): CTA group size (1 or 2 for SM100).
Fields
- tma_op (TileLoaderTMAIm2col[tma_origin, dtype, gmem_layout, desc_layout, cta_group=cta_group].TmaOpPtr): Pointer to the im2col TMA descriptor (TmaOp).
- multicast_mask (UInt16):
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
TrivialRegisterType
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
TmaOp
comptime TmaOp = TMATensorTileIm2col[dtype, gmem_layout, desc_layout]
TmaOpPtr
comptime TmaOpPtr = Pointer[TileLoaderTMAIm2col[tma_origin, dtype, gmem_layout, desc_layout, cta_group=cta_group].TmaOp, tma_origin]
Methods
__init__
__init__(tma_op: Pointer[TileLoaderTMAIm2col[tma_origin, dtype, gmem_layout, desc_layout, cta_group=cta_group].TmaOp, tma_origin], multicast_mask: UInt16) -> Self
load
load[tile_layout: Layout, /, alignment: Int = 128](self, dest: LayoutTensor[dtype, tile_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment], ref[AddressSpace._value._mlir_value] barrier: SharedMemBarrier, k_coord: Scalar[DType.uint], m_coord: Scalar[DType.uint])
Load an activation tile using im2col TMA.
Note: Uses non-multicast TMA because CUTLASS disables multicast im2col loads on SM100/SM120 (its SM90_TMA_LOAD_IM2COL_MULTICAST path is marked CUTE_INVALID_CONTROL_PATH).
Args:
- dest (LayoutTensor): Destination SMEM tile.
- barrier (SharedMemBarrier): Memory barrier for TMA completion signaling.
- k_coord (Scalar): K dimension coordinate (C * R * S indexing).
- m_coord (Scalar): M dimension coordinate (batch * H_out * W_out indexing).
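As a hedged illustration of how a caller might drive `load` across the reduction, the sketch below enumerates the (k_coord, m_coord) pairs for one output tile. It assumes both coordinates are element offsets in GEMM space, that K is ordered (r, s, c) with the channel fastest, and uses hypothetical tile extents BM = 128 and BK = 64 and a hypothetical `ceil_div` helper. It issues no TMA operations; the comment marks where a kernel would call `load`.

```mojo
# Hypothetical sketch (not part of the TileLoaderTMAIm2col API): enumerate the
# GEMM-space coordinates a mainloop might pass to `load` for one M tile.
# Assumptions: k_coord/m_coord are element offsets, K is ordered (r, s, c)
# with the channel c fastest, and BM/BK are illustrative tile extents.


fn ceil_div(a: Int, b: Int) -> Int:
    # Number of size-b tiles needed to cover a elements.
    return (a + b - 1) // b


def main():
    var c_size = 128  # input channels C
    var r_size = 3  # filter height R
    var s_size = 3  # filter width S
    var bm = 128  # hypothetical M tile extent
    var bk = 64  # hypothetical K tile extent; C is a multiple of BK here
    var tile_m = 3  # index of the M tile this CTA owns (hypothetical)

    var m_coord = tile_m * bm
    var k_total = c_size * r_size * s_size  # K = C * R * S = 1152
    for kt in range(ceil_div(k_total, bk)):
        var k_coord = kt * bk
        # Under the assumed ordering, each K tile stays inside one filter tap.
        var r = k_coord // (s_size * c_size)
        var s = (k_coord // c_size) % s_size
        var c0 = k_coord % c_size
        # A kernel would call loader.load(dest, barrier, k_coord, m_coord) here
        # (coordinates converted to Scalar[DType.uint]) and wait on the barrier
        # before the MMA reads dest.
        print("k_coord =", k_coord, "m_coord =", m_coord, "tap (r, s) =", r, s, "c0 =", c0)
```

Keeping C a multiple of BK means every K tile covers a channel slice at a single filter tap (r, s). Hardware im2col requests take one filter offset per load, so this is a common constraint, though this page does not state the loader's exact requirements.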
load[LayoutType: TensorLayout](self, dest: TileTensor[dtype, LayoutType, MutAnyOrigin, address_space=AddressSpace.SHARED], ref[AddressSpace._value._mlir_value] barrier: SharedMemBarrier, k_coord: Scalar[DType.uint], m_coord: Scalar[DType.uint])
Load a TileTensor tile using im2col TMA.
Args:
- dest (TileTensor): Destination SMEM TileTensor tile.
- barrier (SharedMemBarrier): Memory barrier for TMA completion signaling.
- k_coord (Scalar): K dimension coordinate (C * R * S indexing).
- m_coord (Scalar): M dimension coordinate (batch * H_out * W_out indexing).