Mojo struct

TMATensorTileIm2col

struct TMATensorTileIm2col[dtype: DType, layout: Layout, desc_layout: Layout = layout]

TMA tensor tile with im2col coordinate transformation for convolution.

This struct enables hardware-accelerated im2col transformation during TMA loads, used for implicit GEMM convolution. The TMA descriptor encodes the convolution geometry (padding, stride, dilation) and performs coordinate transformation on-the-fly.

The coordinate system uses GEMM-style 2D coordinates:

  • coords[0]: K coordinate (indexes into R * S * C reduction dimension)
  • coords[1]: M coordinate (indexes into batch * H_out * W_out spatial)

Internally:

  • K is decomposed into (c, r, s) where K = r*S*C + s*C + c (filter-first, channel-last for NHWC)
  • M is decomposed into (n, h, w) where M = n*H_out*W_out + h*W_out + w
  • 4D coordinates (c, w, h, n) and filter offsets (s, r) are passed to the PTX im2col instruction.
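The decomposition above can be sketched as a small host-side reference model. This is plain Python for illustration only, not the Mojo implementation; the function names decompose_k and decompose_m are hypothetical, and the formulas assumed are K = r*S*C + s*C + c and M = n*H_out*W_out + h*W_out + w.

```python
def decompose_k(k, filter_w, in_channels):
    """Split a K index assuming K = r*S*C + s*C + c (filter-first, channel-last)."""
    r, rem = divmod(k, filter_w * in_channels)
    s, c = divmod(rem, in_channels)
    return c, r, s


def decompose_m(m, out_height, out_width):
    """Split an M index assuming M = n*H_out*W_out + h*W_out + w."""
    n, rem = divmod(m, out_height * out_width)
    h, w = divmod(rem, out_width)
    return n, h, w


# Example: 3x3 filter (S = 3) with C = 8 channels.
# 43 = 1*(3*8) + 2*8 + 3, so r = 1, s = 2, c = 3.
c, r, s = decompose_k(43, filter_w=3, in_channels=8)

# Example: H_out = 4, W_out = 5.
# 27 = 1*(4*5) + 1*5 + 2, so n = 1, h = 1, w = 2.
n, h, w = decompose_m(27, out_height=4, out_width=5)
```

Note that filter height (R) is not needed to split K in this ordering; it only bounds the valid range of r.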

Parameters

  • dtype (DType): The data type of tensor elements.
  • layout (Layout): The layout of the tile in shared memory.
  • desc_layout (Layout): The layout of the descriptor (may differ for WGMMA compatibility).

Fields

  • descriptor (TMADescriptor): The TMA descriptor encoding im2col transformation parameters.
  • out_height (UInt32): Output height (H_out) for M coordinate decomposition.
  • out_width (UInt32): Output width (W_out) for M coordinate decomposition.
  • filter_h (UInt32): Filter height (R) for K coordinate decomposition.
  • filter_w (UInt32): Filter width (S) for K coordinate decomposition.
  • in_channels (UInt32): Input channels (C) for K coordinate decomposition.
  • lower_corner_h (Int32): Lower corner offset for height (H dimension) - matches CUTLASS ArithmeticTupleIterator pattern.
  • lower_corner_w (Int32): Lower corner offset for width (W dimension) - matches CUTLASS ArithmeticTupleIterator pattern.

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDestructible, Movable

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = False

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

device_type

comptime device_type = TMATensorTileIm2col[dtype, layout, desc_layout]

The device-side type representation.

Methods

__init__

__init__(out self, descriptor: TMADescriptor, out_height: UInt32, out_width: UInt32, filter_h: UInt32, filter_w: UInt32, in_channels: UInt32, lower_corner_h: Int32 = 0, lower_corner_w: Int32 = 0)

Initializes with the provided TMA im2col descriptor and dimensions.

Args:

  • descriptor (TMADescriptor): The TMA descriptor that encodes im2col transformation.
  • out_height (UInt32): Output height (H_out) for M coordinate decomposition.
  • out_width (UInt32): Output width (W_out) for M coordinate decomposition.
  • filter_h (UInt32): Filter height (R) for K coordinate decomposition.
  • filter_w (UInt32): Filter width (S) for K coordinate decomposition.
  • in_channels (UInt32): Input channels (C) for K coordinate decomposition.
  • lower_corner_h (Int32): Lower corner offset for H dimension (matches CUTLASS pattern).
  • lower_corner_w (Int32): Lower corner offset for W dimension (matches CUTLASS pattern).
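As a rough illustration of where these scalar arguments might come from, the sketch below derives them from a convolution's geometry using the standard output-size formula. It is plain Python, not Mojo; im2col_tile_args is a hypothetical helper, and setting the lower corner to the negated padding is an assumption based on the CUTLASS convention mentioned above, not something this page confirms.

```python
def conv_output_size(in_size, filter_size, pad, stride=1, dilation=1):
    """Standard convolution output extent along one spatial axis."""
    effective_filter = dilation * (filter_size - 1) + 1
    return (in_size + 2 * pad - effective_filter) // stride + 1


def im2col_tile_args(in_h, in_w, in_channels, filter_h, filter_w,
                     pad_h=0, pad_w=0, stride=1, dilation=1):
    """Collect the scalar __init__ arguments (hypothetical helper)."""
    return {
        "out_height": conv_output_size(in_h, filter_h, pad_h, stride, dilation),
        "out_width": conv_output_size(in_w, filter_w, pad_w, stride, dilation),
        "filter_h": filter_h,
        "filter_w": filter_w,
        "in_channels": in_channels,
        # Assumption: the lower corner is the negated padding (CUTLASS-style).
        "lower_corner_h": -pad_h,
        "lower_corner_w": -pad_w,
    }


# 32x32 input, 64 channels, 3x3 filter, padding 1: "same" output size.
args = im2col_tile_args(32, 32, 64, 3, 3, pad_h=1, pad_w=1)
```

With no padding, both lower corners stay at 0, which matches the defaults in the signature.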

__copyinit__

__copyinit__(out self, other: Self)

Copy initializes from another instance.

Args:

  • other (Self): The other instance to copy from.

get_type_name

static get_type_name() -> String

Gets this type's name for error messages.

Returns:

String: This type's name.

prefetch_descriptor

prefetch_descriptor(self)

Prefetches the TMA descriptor into cache.

async_copy

async_copy[cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ref[AddressSpace._value._mlir_value] mem_barrier: SharedMemBarrier, coords: Tuple[UInt, UInt])

Schedules an asynchronous im2col TMA load.

Uses 2D GEMM-style coordinates:

  • coords[0]: K coordinate (indexes into C * R * S reduction dimension)
  • coords[1]: M coordinate (indexes into batch * H_out * W_out spatial)

Internally:

  • K is decomposed into (c, r, s) where K = c*R*S + r*S + s
  • M is decomposed into (n, h, w) where M = n*H_out*W_out + h*W_out + w
  • 4D coordinates (c, w, h, n) and filter offsets (s, r) are passed to the PTX im2col instruction.

Note: SM100/Blackwell im2col TMA with padding (negative corners) requires the cta_group::2 PTX format, so pass cta_group=2 in that case; the parameter itself defaults to 1, as the signature shows. This is consistent with CUTLASS, which only provides SM100_TMA_2SM_LOAD_IM2COL (no cta_group::1 variant for im2col).

Parameters:

  • cta_group (Int): CTA group size for TMA operations.
  • eviction_policy (CacheEviction): Cache eviction policy for the TMA load.

Args:

  • dst (LayoutTensor): Destination tensor in shared memory.
  • mem_barrier (SharedMemBarrier): Memory barrier for synchronization.
  • coords (Tuple): GEMM coordinates (k_coord, m_coord).
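The coords tuple is ordered (k_coord, m_coord), K first. The sketch below shows one way a caller might enumerate tile origins in that order. It is plain Python for illustration; tile_coords, block_m, and block_k are hypothetical names, and treating each coordinate as the element offset of a tile origin is an assumption.

```python
def tile_coords(m_total, k_total, block_m, block_k):
    """Yield (k_coord, m_coord) tile origins, iterating K innermost.

    Assumption: each coordinate is an element offset, not a tile index.
    """
    for m0 in range(0, m_total, block_m):
        for k0 in range(0, k_total, block_k):
            yield (k0, m0)


# Example: batch 2 with a 4x4 output gives M = 2*4*4 = 32.
# C = 8 channels with a 3x3 filter gives K = 8*3*3 = 72.
coords = list(tile_coords(m_total=32, k_total=72, block_m=16, block_k=24))
# coords == [(0, 0), (24, 0), (48, 0), (0, 16), (24, 16), (48, 16)]
```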

async_copy[cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: TileTensor[dtype, LayoutType, origin, address_space=AddressSpace.SHARED, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], ref[AddressSpace._value._mlir_value] mem_barrier: SharedMemBarrier, coords: Tuple[UInt, UInt])

Schedules an asynchronous im2col TMA load.

TileTensor overload - accepts TileTensor instead of LayoutTensor. Assumes 128B alignment (TileTensor tiles are allocated with proper alignment).

Uses 2D GEMM-style coordinates:

  • coords[0]: K coordinate (indexes into C * R * S reduction dimension)
  • coords[1]: M coordinate (indexes into batch * H_out * W_out spatial)

Internally:

  • K is decomposed into (c, r, s) where K = c*R*S + r*S + s
  • M is decomposed into (n, h, w) where M = n*H_out*W_out + h*W_out + w
  • 4D coordinates (c, w, h, n) and filter offsets (s, r) are passed to the PTX im2col instruction.

Note: Uses cta_group=1 (SM90-style TMA) for single-CTA clusters.

Parameters:

  • cta_group (Int): CTA group size for TMA operations.
  • eviction_policy (CacheEviction): Cache eviction policy for the TMA load.

Args:

  • dst (TileTensor): TileTensor in shared memory where data will be copied.
  • mem_barrier (SharedMemBarrier): Memory barrier for synchronization.
  • coords (Tuple): GEMM coordinates (k_coord, m_coord).

async_multicast_load

async_multicast_load[cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ref[AddressSpace._value._mlir_value] mem_barrier: SharedMemBarrier, coords: Tuple[UInt, UInt], multicast_mask: UInt16)

Schedules an asynchronous im2col TMA load with multicast.

Uses 2D GEMM-style coordinates:

  • coords[0]: K coordinate (indexes into C * R * S reduction dimension)
  • coords[1]: M coordinate (indexes into batch * H_out * W_out spatial)

Internally:

  • K is decomposed into (c, r, s) where K = c*R*S + r*S + s
  • M is decomposed into (n, h, w) where M = n*H_out*W_out + h*W_out + w
  • 4D coordinates (c, w, h, n) and filter offsets (s, r) are passed to the PTX im2col instruction with multicast.

Note: SM100/Blackwell im2col TMA with padding (negative corners) requires the cta_group::2 PTX format, so pass cta_group=2 in that case; the parameter itself defaults to 1, as the signature shows. This is consistent with CUTLASS, which only provides SM100_TMA_2SM_LOAD_IM2COL_MULTICAST (no cta_group::1 variant).

Parameters:

  • cta_group (Int): CTA group size for TMA operations.
  • eviction_policy (CacheEviction): Cache eviction policy for the TMA load.

Args:

  • dst (LayoutTensor): Destination tensor in shared memory.
  • mem_barrier (SharedMemBarrier): Memory barrier for synchronization.
  • coords (Tuple): GEMM coordinates (k_coord, m_coord).
  • multicast_mask (UInt16): Bitmask specifying target CTAs for multicast.
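A minimal sketch of building such a bitmask on the host, in plain Python. The helper name make_multicast_mask is hypothetical, and the mapping of bit i to the CTA with rank i inside the cluster is assumed from the PTX multicast convention rather than stated on this page.

```python
def make_multicast_mask(cta_ranks):
    """Build a 16-bit mask with one bit set per target CTA rank.

    Assumption: bit i selects the CTA whose rank within the cluster is i.
    """
    mask = 0
    for rank in cta_ranks:
        if not 0 <= rank < 16:
            raise ValueError(f"CTA rank {rank} does not fit in a UInt16 mask")
        mask |= 1 << rank
    return mask


# Target CTAs 0 and 2 in the cluster: bits 0 and 2 set.
mask_pair = make_multicast_mask([0, 2])      # 0b0101 == 5

# Target all four CTAs of a 4-CTA cluster.
mask_all = make_multicast_mask(range(4))     # 0b1111 == 15
```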

async_multicast_load[cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: TileTensor[dtype, LayoutType, origin, address_space=AddressSpace.SHARED, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], ref[AddressSpace._value._mlir_value] mem_barrier: SharedMemBarrier, coords: Tuple[UInt, UInt], multicast_mask: UInt16)

Schedules an asynchronous im2col TMA load with multicast.

TileTensor overload - accepts TileTensor instead of LayoutTensor. Assumes 128B alignment (TileTensor tiles are allocated with proper alignment).

Uses 2D GEMM-style coordinates:

  • coords[0]: K coordinate (indexes into C * R * S reduction dimension)
  • coords[1]: M coordinate (indexes into batch * H_out * W_out spatial)

Internally:

  • K is decomposed into (c, r, s) where K = c*R*S + r*S + s
  • M is decomposed into (n, h, w) where M = n*H_out*W_out + h*W_out + w
  • 4D coordinates (c, w, h, n) and filter offsets (s, r) are passed to the PTX im2col instruction with multicast.

Note: Uses cta_group=1 (SM90-style TMA) for single-CTA clusters.

Parameters:

  • cta_group (Int): CTA group size for TMA operations.
  • eviction_policy (CacheEviction): Cache eviction policy for the TMA load.

Args:

  • dst (TileTensor): TileTensor in shared memory where data will be copied.
  • mem_barrier (SharedMemBarrier): Memory barrier for synchronization.
  • coords (Tuple): GEMM coordinates (k_coord, m_coord).
  • multicast_mask (UInt16): Bitmask specifying target CTAs for multicast.
