Mojo struct
TMATensorTileIm2col
struct TMATensorTileIm2col[dtype: DType, layout: Layout, desc_layout: Layout = layout]
TMA tensor tile with im2col coordinate transformation for convolution.
This struct enables hardware-accelerated im2col transformation during TMA loads, as used by implicit GEMM convolution. The TMA descriptor encodes the convolution geometry (padding, stride, dilation), and the hardware performs the coordinate transformation on the fly.
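The implicit GEMM view can be checked on the host. The plain-Python reference model below does not call the Mojo API. It materializes the im2col matrix that the TMA descriptor forms on the fly, using the orderings documented below (M = n*H_out*W_out + h*W_out + w, K = r*S*C + s*C + c). It then checks that a plain GEMM against the filters reproduces a direct NHWC convolution. All function names are illustrative, and padding positions are modeled as zeros.

```python
# Host-side reference model in plain Python (it does not call the Mojo API).


def out_size(size, filt, pad, stride, dil):
    # Standard convolution output-size formula with symmetric padding.
    return (size + 2 * pad - dil * (filt - 1) - 1) // stride + 1


def im2col_nhwc(x, R, S, pad, stride, dil):
    """x[n][h][w][c] -> M x K matrix with M = n*H_out*W_out + h*W_out + w
    and K = r*S*C + s*C + c (filter-first, channel-last)."""
    N, H, W, C = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    H_out = out_size(H, R, pad, stride, dil)
    W_out = out_size(W, S, pad, stride, dil)
    rows = []
    for n in range(N):
        for ho in range(H_out):
            for wo in range(W_out):
                row = []
                for r in range(R):
                    for s in range(S):
                        hi = ho * stride - pad + r * dil
                        wi = wo * stride - pad + s * dil
                        inside = 0 <= hi < H and 0 <= wi < W
                        for c in range(C):
                            # Padding positions read as zero.
                            row.append(x[n][hi][wi][c] if inside else 0)
                rows.append(row)
    return rows


def filter_matrix(wts):
    """wts[k][r][s][c] -> K x K_out matrix using the same K ordering."""
    K_out, R, S, C = len(wts), len(wts[0]), len(wts[0][0]), len(wts[0][0][0])
    return [[wts[ko][r][s][c] for ko in range(K_out)]
            for r in range(R) for s in range(S) for c in range(C)]


def matmul(a, b):
    return [[sum(row[t] * b[t][j] for t in range(len(b)))
             for j in range(len(b[0]))] for row in a]


def conv_direct(x, wts, pad, stride, dil):
    """Direct NHWC convolution; output rows use the same M ordering."""
    N, H, W, C = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    K_out, R, S = len(wts), len(wts[0]), len(wts[0][0])
    H_out = out_size(H, R, pad, stride, dil)
    W_out = out_size(W, S, pad, stride, dil)
    y = []
    for n in range(N):
        for ho in range(H_out):
            for wo in range(W_out):
                out = []
                for ko in range(K_out):
                    acc = 0
                    for r in range(R):
                        for s in range(S):
                            hi = ho * stride - pad + r * dil
                            wi = wo * stride - pad + s * dil
                            if 0 <= hi < H and 0 <= wi < W:
                                for c in range(C):
                                    acc += x[n][hi][wi][c] * wts[ko][r][s][c]
                    out.append(acc)
                y.append(out)
    return y


# Small integer example: N=1, H=4, W=5, C=2, K_out=3, 3x3 filter.
x = [[[[(7 * h + 3 * w + c) % 5 - 2 for c in range(2)]
       for w in range(5)] for h in range(4)]]
wts = [[[[(k + 2 * r + s + c) % 3 - 1 for c in range(2)]
         for s in range(3)] for r in range(3)] for k in range(3)]
a = im2col_nhwc(x, 3, 3, pad=1, stride=2, dil=1)
y_gemm = matmul(a, filter_matrix(wts))
assert y_gemm == conv_direct(x, wts, pad=1, stride=2, dil=1)
```

The GEMM never needs to know about padding, stride, or dilation. All of that lives in how the rows of the im2col matrix are gathered, and that gathering is the work the TMA descriptor takes over.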
The coordinate system uses GEMM-style 2D coordinates:
- coords[0]: K coordinate (indexes into the R * S * C reduction dimension)
- coords[1]: M coordinate (indexes into the batch * H_out * W_out spatial dimension)
Internally:
- K is decomposed into (c, r, s) where K = r*S*C + s*C + c (filter-first, channel-last for NHWC)
- M is decomposed into (n, h, w) where M = n*H_out*W_out + h*W_out + w
- 4D coordinates (c, w, h, n) and filter offsets (s, r) are passed to the PTX im2col instruction.
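The two decompositions above are ordinary mixed-radix splits. The plain-Python model below shows them; it is not the Mojo implementation. The helper im2col_operands is hypothetical. It only regroups the indices in the order listed above and leaves out any stride or lower-corner adjustment the struct may apply before issuing the PTX instruction.

```python
# Plain-Python model of the GEMM-coordinate decomposition (not the Mojo code).
#   K = r*S*C + s*C + c        M = n*H_out*W_out + h*W_out + w


def decompose_k(k, S, C):
    """Split a K coordinate into (c, r, s); c varies fastest."""
    c = k % C
    s = (k // C) % S
    r = k // (S * C)
    return c, r, s


def decompose_m(m, H_out, W_out):
    """Split an M coordinate into (n, h, w); w varies fastest."""
    w = m % W_out
    h = (m // W_out) % H_out
    n = m // (H_out * W_out)
    return n, h, w


def im2col_operands(k, m, S, C, H_out, W_out):
    """Hypothetical helper: 4D coordinates (c, w, h, n) plus filter
    offsets (s, r). Stride and lower-corner handling are not modeled."""
    c, r, s = decompose_k(k, S, C)
    n, h, w = decompose_m(m, H_out, W_out)
    return (c, w, h, n), (s, r)


# 3x3 filter with C = 64 and a 56x56 output.
print(decompose_k(100, S=3, C=64))            # (36, 0, 1)
print(decompose_m(3200, H_out=56, W_out=56))  # (1, 1, 8)
```

Note that R does not appear in decompose_k. The outermost digit r is whatever remains after removing s and c, so R only bounds the valid range of K.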
Parameters
- dtype (DType): The data type of tensor elements.
- layout (Layout): The layout of the tile in shared memory.
- desc_layout (Layout): The layout of the descriptor (may differ for WGMMA compatibility).
Fields
- descriptor (TMADescriptor): The TMA descriptor encoding the im2col transformation parameters.
- out_height (UInt32): Output height (H_out) for M coordinate decomposition.
- out_width (UInt32): Output width (W_out) for M coordinate decomposition.
- filter_h (UInt32): Filter height (R) for K coordinate decomposition.
- filter_w (UInt32): Filter width (S) for K coordinate decomposition.
- in_channels (UInt32): Input channels (C) for K coordinate decomposition.
- lower_corner_h (Int32): Lower-corner offset for the height (H) dimension; matches the CUTLASS ArithmeticTupleIterator pattern.
- lower_corner_w (Int32): Lower-corner offset for the width (W) dimension; matches the CUTLASS ArithmeticTupleIterator pattern.
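The lower-corner fields place each filter tap relative to the input. The plain-Python sketch below makes that concrete for one dimension, under two loudly stated assumptions. First, lower_corner equals -padding, which is consistent with padded convolutions having negative corners. Second, input positions outside the tensor read as zero, mirroring TMA's usual zero fill for out-of-bounds elements. The function names are hypothetical.

```python
# Plain-Python sketch (not the Mojo API). Assumes lower_corner == -padding.


def input_index(out_idx, filt_idx, stride, dilation, lower_corner):
    """Input row (or column) read for one output position and filter tap."""
    return out_idx * stride + lower_corner + filt_idx * dilation


def read_with_zero_fill(row, idx):
    # Model out-of-bounds handling: positions outside the tensor read as 0.
    return row[idx] if 0 <= idx < len(row) else 0


# 3-tap filter, padding 1 (lower_corner = -1), stride 1, dilation 1.
row = [10, 20, 30, 40]
taps = [read_with_zero_fill(row, input_index(0, r, 1, 1, -1)) for r in range(3)]
print(taps)  # [0, 10, 20]
```

The first tap of output 0 lands on index -1, which is padding and therefore reads as zero. The last output behaves the same way on the other edge.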
Implemented traits
AnyType,
Copyable,
DevicePassable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = False
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
device_type
comptime device_type = TMATensorTileIm2col[dtype, layout, desc_layout]
The device-side type representation.
Methods
__init__
__init__(out self, descriptor: TMADescriptor, out_height: UInt32, out_width: UInt32, filter_h: UInt32, filter_w: UInt32, in_channels: UInt32, lower_corner_h: Int32 = 0, lower_corner_w: Int32 = 0)
Initializes with the provided TMA im2col descriptor and dimensions.
Args:
- descriptor (TMADescriptor): The TMA descriptor that encodes the im2col transformation.
- out_height (UInt32): Output height (H_out) for M coordinate decomposition.
- out_width (UInt32): Output width (W_out) for M coordinate decomposition.
- filter_h (UInt32): Filter height (R) for K coordinate decomposition.
- filter_w (UInt32): Filter width (S) for K coordinate decomposition.
- in_channels (UInt32): Input channels (C) for K coordinate decomposition.
- lower_corner_h (Int32): Lower-corner offset for the H dimension (matches the CUTLASS pattern).
- lower_corner_w (Int32): Lower-corner offset for the W dimension (matches the CUTLASS pattern).
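Apart from the descriptor, the constructor's scalar arguments follow directly from the convolution geometry. The plain-Python helper below derives them. It is hypothetical and not part of this API: it does not build the descriptor itself, and it assumes symmetric padding with lower_corner = -padding.

```python
# Hypothetical plain-Python helper; the descriptor argument is not modeled.


def conv_out_size(in_size, filt, pad, stride, dilation):
    # Standard convolution output-size formula.
    return (in_size + 2 * pad - dilation * (filt - 1) - 1) // stride + 1


def im2col_init_args(H, W, C, R, S, pad_h, pad_w, stride_h, stride_w,
                     dil_h=1, dil_w=1):
    return dict(
        out_height=conv_out_size(H, R, pad_h, stride_h, dil_h),
        out_width=conv_out_size(W, S, pad_w, stride_w, dil_w),
        filter_h=R,
        filter_w=S,
        in_channels=C,
        lower_corner_h=-pad_h,  # assumption: corner = -padding
        lower_corner_w=-pad_w,
    )


args = im2col_init_args(H=56, W=56, C=64, R=3, S=3,
                        pad_h=1, pad_w=1, stride_h=2, stride_w=2)
print(args["out_height"], args["out_width"])  # 28 28
```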
__copyinit__
__copyinit__(out self, other: Self)
Copy initializes from another instance.
Args:
- other (Self): The other instance to copy from.
get_type_name
static get_type_name() -> String
Gets this type's name for error messages.
Returns:
String: This type's name.
prefetch_descriptor
prefetch_descriptor(self)
Prefetches the TMA descriptor into cache.
async_copy
async_copy[cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ref[AddressSpace._value._mlir_value] mem_barrier: SharedMemBarrier, coords: Tuple[UInt, UInt])
Schedules an asynchronous im2col TMA load.
Uses 2D GEMM-style coordinates:
- coords[0]: K coordinate (indexes into the R * S * C reduction dimension)
- coords[1]: M coordinate (indexes into the batch * H_out * W_out spatial dimension)
Internally:
- K is decomposed into (c, r, s) where K = r*S*C + s*C + c (filter-first, channel-last for NHWC)
- M is decomposed into (n, h, w) where M = n*H_out*W_out + h*W_out + w
- 4D coordinates (c, w, h, n) and filter offsets (s, r) are passed to the PTX im2col instruction.
Note: Although cta_group defaults to 1 in the signature, SM100/Blackwell im2col TMA with padding (negative corners) requires the cta_group::2 PTX format. This is consistent with CUTLASS, which only provides SM100_TMA_2SM_LOAD_IM2COL (no cta_group::1 variant for im2col).
Parameters:
- cta_group (Int): CTA group size for TMA operations.
- eviction_policy (CacheEviction): Cache eviction policy for the TMA load.
Args:
- dst (LayoutTensor): Destination tensor in shared memory.
- mem_barrier (SharedMemBarrier): Memory barrier for synchronization.
- coords (Tuple): GEMM coordinates (k_coord, m_coord).
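With the filter-first, channel-last ordering K = r*S*C + s*C + c, a tile of consecutive K values stays within a single filter tap (r, s) whenever its width divides C. In that case one load needs only one pair of filter offsets. The plain-Python sketch below walks K in steps of a hypothetical tile width BK. It reports the k_coord that would go into coords[0], together with the tap and starting channel it covers. BK and k_tiles are illustrative names, not part of this API.

```python
# Plain-Python sketch; BK and k_tiles are hypothetical names.


def k_tiles(R, S, C, BK):
    """Yield (k_coord, r, s, c_begin) for each BK-wide tile along K."""
    if C % BK != 0:
        raise ValueError("this sketch assumes BK divides C")
    for k in range(0, R * S * C, BK):
        c_begin = k % C
        s = (k // C) % S
        r = k // (S * C)
        yield k, r, s, c_begin


for k, r, s, c0 in k_tiles(R=2, S=2, C=128, BK=64):
    # k would be coords[0]; coords[1] (the M-tile start) is held fixed here.
    print(k, (r, s), c0)
```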
async_copy[cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: TileTensor[dtype, LayoutType, origin, address_space=AddressSpace.SHARED, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], ref[AddressSpace._value._mlir_value] mem_barrier: SharedMemBarrier, coords: Tuple[UInt, UInt])
Schedules an asynchronous im2col TMA load.
TileTensor overload: accepts a TileTensor instead of a LayoutTensor. Assumes 128-byte alignment (TileTensor tiles are allocated with the required alignment).
Uses 2D GEMM-style coordinates:
- coords[0]: K coordinate (indexes into the R * S * C reduction dimension)
- coords[1]: M coordinate (indexes into the batch * H_out * W_out spatial dimension)
Internally:
- K is decomposed into (c, r, s) where K = r*S*C + s*C + c (filter-first, channel-last for NHWC)
- M is decomposed into (n, h, w) where M = n*H_out*W_out + h*W_out + w
- 4D coordinates (c, w, h, n) and filter offsets (s, r) are passed to the PTX im2col instruction.
Note: This overload defaults to cta_group=1 (SM90-style TMA) for single-CTA clusters.
Parameters:
- cta_group (Int): CTA group size for TMA operations.
- eviction_policy (CacheEviction): Cache eviction policy for the TMA load.
Args:
- dst (TileTensor): TileTensor in shared memory where data will be copied.
- mem_barrier (SharedMemBarrier): Memory barrier for synchronization.
- coords (Tuple): GEMM coordinates (k_coord, m_coord).
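Because this overload assumes 128-byte alignment, any code that carves several tiles out of one shared-memory buffer has to keep each tile start on a 128-byte boundary. The plain-Python sketch below shows a hypothetical offset planner that does this. It is not part of the API, and it only computes byte offsets.

```python
# Hypothetical plain-Python offset planner; not part of the API.

ALIGN = 128  # bytes, the alignment this overload assumes


def align_up(offset, align=ALIGN):
    # Round offset up to the next multiple of align.
    return (offset + align - 1) // align * align


def tile_offsets(tile_bytes, count, base=0):
    """Return 128-byte-aligned start offsets for `count` tiles."""
    offsets = []
    cur = base
    for _ in range(count):
        cur = align_up(cur)
        offsets.append(cur)
        cur += tile_bytes
    return offsets


# e.g. a 128 x 64 bf16 tile occupies 128 * 64 * 2 = 16384 bytes.
print(tile_offsets(16384, 2, base=8))  # [128, 16512]
```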
async_multicast_load
async_multicast_load[cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ref[AddressSpace._value._mlir_value] mem_barrier: SharedMemBarrier, coords: Tuple[UInt, UInt], multicast_mask: UInt16)
Schedules an asynchronous im2col TMA load with multicast.
Uses 2D GEMM-style coordinates:
- coords[0]: K coordinate (indexes into the R * S * C reduction dimension)
- coords[1]: M coordinate (indexes into the batch * H_out * W_out spatial dimension)
Internally:
- K is decomposed into (c, r, s) where K = r*S*C + s*C + c (filter-first, channel-last for NHWC)
- M is decomposed into (n, h, w) where M = n*H_out*W_out + h*W_out + w
- 4D coordinates (c, w, h, n) and filter offsets (s, r) are passed to the PTX im2col instruction with multicast.
Note: Although cta_group defaults to 1 in the signature, SM100/Blackwell im2col TMA with padding (negative corners) requires the cta_group::2 PTX format. This is consistent with CUTLASS, which only provides SM100_TMA_2SM_LOAD_IM2COL_MULTICAST (no cta_group::1 variant).
Parameters:
- cta_group (Int): CTA group size for TMA operations.
- eviction_policy (CacheEviction): Cache eviction policy for the TMA load.
Args:
- dst (LayoutTensor): Destination tensor in shared memory.
- mem_barrier (SharedMemBarrier): Memory barrier for synchronization.
- coords (Tuple): GEMM coordinates (k_coord, m_coord).
- multicast_mask (UInt16): Bitmask specifying target CTAs for multicast.
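multicast_mask is a 16-bit value with one bit per target CTA. The plain-Python sketch below builds such a mask from a list of CTA ranks. It assumes that bit i selects the CTA with rank i in the cluster, and build_multicast_mask is a hypothetical helper name.

```python
# Plain-Python sketch; assumes bit i of the mask selects cluster CTA rank i.


def build_multicast_mask(target_ranks):
    """OR together one bit per target CTA rank; must fit in 16 bits."""
    mask = 0
    for rank in target_ranks:
        if not 0 <= rank < 16:
            raise ValueError("rank does not fit in a 16-bit mask")
        mask |= 1 << rank
    return mask


print(build_multicast_mask([0, 1, 2, 3]))  # 15
print(build_multicast_mask([1, 3]))        # 10
```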
async_multicast_load[cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: TileTensor[dtype, LayoutType, origin, address_space=AddressSpace.SHARED, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], ref[AddressSpace._value._mlir_value] mem_barrier: SharedMemBarrier, coords: Tuple[UInt, UInt], multicast_mask: UInt16)
Schedules an asynchronous im2col TMA load with multicast.
TileTensor overload: accepts a TileTensor instead of a LayoutTensor. Assumes 128-byte alignment (TileTensor tiles are allocated with the required alignment).
Uses 2D GEMM-style coordinates:
- coords[0]: K coordinate (indexes into the R * S * C reduction dimension)
- coords[1]: M coordinate (indexes into the batch * H_out * W_out spatial dimension)
Internally:
- K is decomposed into (c, r, s) where K = r*S*C + s*C + c (filter-first, channel-last for NHWC)
- M is decomposed into (n, h, w) where M = n*H_out*W_out + h*W_out + w
- 4D coordinates (c, w, h, n) and filter offsets (s, r) are passed to the PTX im2col instruction with multicast.
Note: This overload defaults to cta_group=1 (SM90-style TMA) for single-CTA clusters.
Parameters:
- cta_group (Int): CTA group size for TMA operations.
- eviction_policy (CacheEviction): Cache eviction policy for the TMA load.
Args:
- dst (TileTensor): TileTensor in shared memory where data will be copied.
- mem_barrier (SharedMemBarrier): Memory barrier for synchronization.
- coords (Tuple): GEMM coordinates (k_coord, m_coord).
- multicast_mask (UInt16): Bitmask specifying target CTAs for multicast.
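In a 2D cluster, a tile is typically multicast to every CTA that shares one cluster coordinate. The plain-Python sketch below builds row and column masks for that case. It rests on two assumptions: bit i selects cluster rank i, and ranks are numbered x-fastest (rank = x + y * cluster_x). Which operand a kernel multicasts along which axis is a design choice of that kernel, and the helper names are hypothetical.

```python
# Plain-Python sketch; assumes rank = x + y * cluster_x and bit i = rank i.


def row_mask(y, cluster_x):
    """All CTAs whose cluster y coordinate equals y."""
    return sum(1 << (x + y * cluster_x) for x in range(cluster_x))


def col_mask(x, cluster_x, cluster_y):
    """All CTAs whose cluster x coordinate equals x."""
    return sum(1 << (x + y * cluster_x) for y in range(cluster_y))


def ranks_in_mask(mask):
    # Decode a 16-bit mask back into the selected CTA ranks.
    return [i for i in range(16) if mask >> i & 1]


# 2 x 2 cluster: rows are {0, 1} and {2, 3}; columns are {0, 2} and {1, 3}.
print(ranks_in_mask(row_mask(1, 2)))     # [2, 3]
print(ranks_in_mask(col_mask(0, 2, 2)))  # [0, 2]
```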