Mojo function
create_tensor_tile_im2col
create_tensor_tile_im2col[dtype: DType, tile_shape: IndexList[2], swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, *, __tile_layout: Layout = Layout.row_major(tile_shape.__getitem__[2, DType.int64, Int](0), tile_shape.__getitem__[2, DType.int64, Int](1)), __desc_layout: Layout = _im2col_desc_tile_layout[dtype, tile_shape, swizzle_mode]()](ctx: DeviceContext, tensor: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], lower_corner_h: Int, lower_corner_w: Int, upper_corner_h: Int, upper_corner_w: Int, out_height: Int, out_width: Int, filter_h: Int, filter_w: Int) -> TMATensorTileIm2col[dtype, __tile_layout, __desc_layout]
Creates a TMA tensor tile with im2col transformation for 2D convolution.
This factory function creates a TMA descriptor that performs hardware im2col transformation during loads. The descriptor encodes the convolution geometry and the TMA hardware computes addresses on-the-fly.
For im2col TMA, each transaction loads one output pixel with multiple channels. This follows CUTLASS's approach where:
- pixels_per_column = 1 (one pixel per TMA transaction)
- channels_per_pixel = min(K_tile, swizzle_width) (contiguous channels)
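The box shape described above can be sketched in a few lines of plain Python (not Mojo, and not the library's implementation). The helper name `im2col_box_shape` is hypothetical, and the sketch assumes that swizzle_width is the swizzle span in bytes divided by the element size; the page itself does not define how swizzle_width is obtained.

```python
def im2col_box_shape(k_tile: int, swizzle_bytes: int, elem_bytes: int) -> tuple[int, int]:
    """Return (pixels_per_column, channels_per_pixel) for one TMA transaction.

    Hypothetical helper for illustration only.
    """
    # Assumption: swizzle_width is the swizzle span expressed in elements.
    swizzle_width = swizzle_bytes // elem_bytes
    pixels_per_column = 1  # one pixel per TMA transaction
    channels_per_pixel = min(k_tile, swizzle_width)  # contiguous channels
    return pixels_per_column, channels_per_pixel


# 2-byte elements with a 128-byte swizzle span give a swizzle_width of 64 elements.
print(im2col_box_shape(k_tile=64, swizzle_bytes=128, elem_bytes=2))
print(im2col_box_shape(k_tile=32, swizzle_bytes=128, elem_bytes=2))
```

When K_tile is smaller than the swizzle width, the box is limited by K_tile; otherwise it is limited by the swizzle width.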
Note: For stride=1, dilation=1 convolution with padding (following CUTLASS convention):
- lower_corner_h = -pad_h
- lower_corner_w = -pad_w
- upper_corner_h = pad_h - (filter_h - 1)
- upper_corner_w = pad_w - (filter_w - 1)
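As a quick check of the corner formulas above, here is a minimal plain-Python sketch (not Mojo) that computes the four corner values from the padding and filter size. The helper name `im2col_corners` is hypothetical; only the four formulas come from this page, and they apply to the stride=1, dilation=1 case.

```python
def im2col_corners(pad_h: int, pad_w: int, filter_h: int, filter_w: int) -> tuple[int, int, int, int]:
    """Corner offsets for a stride=1, dilation=1 convolution (hypothetical helper)."""
    lower_corner_h = -pad_h
    lower_corner_w = -pad_w
    upper_corner_h = pad_h - (filter_h - 1)
    upper_corner_w = pad_w - (filter_w - 1)
    return lower_corner_h, lower_corner_w, upper_corner_h, upper_corner_w


# 3x3 filter with "same" padding of 1 on each side.
print(im2col_corners(pad_h=1, pad_w=1, filter_h=3, filter_w=3))  # (-1, -1, -1, -1)

# 3x3 filter with no padding.
print(im2col_corners(pad_h=0, pad_w=0, filter_h=3, filter_w=3))  # (0, 0, -2, -2)
```

For stride 1, the difference upper_corner - lower_corner equals 2 * pad - (filter - 1), which is the amount by which the output extent differs from the input extent.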
The filter offsets passed to the PTX instruction range from 0 to (filter_size - 1) and are added to lower_corner to compute actual input coordinates.
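The address arithmetic can be illustrated with a small plain-Python sketch (not Mojo). The helper name `input_coords` is hypothetical, and the sketch assumes stride=1 and dilation=1, so that the output pixel index serves as the base position to which lower_corner and the filter offset are added. The hardware performs the equivalent computation internally; this only mirrors the arithmetic.

```python
def input_coords(h_out: int, w_out: int, r: int, s: int,
                 lower_corner_h: int, lower_corner_w: int) -> tuple[int, int]:
    """Input (h, w) read for output pixel (h_out, w_out) at filter offset (r, s).

    Hypothetical helper; assumes stride=1 and dilation=1.
    """
    h_in = h_out + lower_corner_h + r
    w_in = w_out + lower_corner_w + s
    return h_in, w_in


# 3x3 filter, pad=1, so lower_corner_h = lower_corner_w = -1.
# Filter offset (0, 0) at output pixel (0, 0) lands in the padding region.
print(input_coords(0, 0, 0, 0, -1, -1))  # (-1, -1)

# The filter center (1, 1) lands on the matching input pixel.
print(input_coords(0, 0, 1, 1, -1, -1))  # (0, 0)
```

Negative coordinates correspond to positions in the padding region rather than real input pixels.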
Parameters:
- dtype (DType): The data type of tensor elements.
- tile_shape (IndexList): Shape [M_tile, K_tile] for the GEMM tile.
  - M_tile: Number of output pixels (batch * H_out * W_out slice).
  - K_tile: Number of channels (C_in * R * S slice for filter).
- swizzle_mode (TensorMapSwizzle): Memory swizzling pattern.
- __tile_layout (Layout): Internal layout parameter (full tile shape).
- __desc_layout (Layout): Internal descriptor layout parameter (TMA box shape).
Args:
- ctx (DeviceContext): The CUDA device context.
- tensor (LayoutTensor): The 4D activation tensor in NHWC layout.
- lower_corner_h (Int): Lower corner offset for height (negative for padding).
- lower_corner_w (Int): Lower corner offset for width (negative for padding).
- upper_corner_h (Int): Upper corner offset for height.
- upper_corner_w (Int): Upper corner offset for width.
- out_height (Int): Output height (H_out) for M coordinate decomposition.
- out_width (Int): Output width (W_out) for M coordinate decomposition.
- filter_h (Int): Filter height (R) for K coordinate decomposition.
- filter_w (Int): Filter width (S) for K coordinate decomposition.
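The M and K coordinate decompositions mentioned in the arguments can be sketched in plain Python (not Mojo). Both helper names are hypothetical. The sketch assumes M is linearized as (n, h_out, w_out) with w_out varying fastest, and K as (r, s, c) with the channel index c varying fastest; this page does not state the ordering, so treat both as assumptions.

```python
def decompose_m(m: int, out_height: int, out_width: int) -> tuple[int, int, int]:
    """Split a linear M index into (n, h_out, w_out).

    Assumption: w_out varies fastest, then h_out, then the batch index n.
    """
    n, rem = divmod(m, out_height * out_width)
    h_out, w_out = divmod(rem, out_width)
    return n, h_out, w_out


def decompose_k(k: int, filter_w: int, channels: int) -> tuple[int, int, int]:
    """Split a linear K index into (r, s, c).

    Assumption: the channel index c varies fastest, then s, then r.
    Under this ordering filter_h only bounds r and is not needed to split k.
    """
    rs, c = divmod(k, channels)
    r, s = divmod(rs, filter_w)
    return r, s, c


# 4x4 output: linear index 21 is image 1, row 1, column 1.
print(decompose_m(21, out_height=4, out_width=4))  # (1, 1, 1)

# 3x3 filter over 16 channels: linear index 70 is filter tap (1, 1), channel 6.
print(decompose_k(70, filter_w=3, channels=16))  # (1, 1, 6)
```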
Returns:
TMATensorTileIm2col: A TMATensorTileIm2col configured for im2col loads.
Raises:
Error if TMA descriptor creation fails.