
Mojo function

create_tma_tile_gather4

create_tma_tile_gather4[dtype: DType, *, tile_height: Int = 4, tile_width: Int, tile_stride: Int = tile_width, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](ctx: DeviceContext, device_buf: DeviceBuffer[dtype], num_rows: Int) -> TMATensorTile[dtype, 2, IndexList(tile_height, _gather4_box_width[dtype, tile_width, swizzle_mode](), __list_literal__=Tuple()), IndexList(1, _gather4_box_width[dtype, tile_width, swizzle_mode](), __list_literal__=Tuple())]

Creates a TMATensorTile for gather4 with automatic box-width computation.

The global tensor has tile_width elements per row. The TMA box width is derived from the swizzle mode so that each gather4 call loads one swizzle-group-sized column chunk (for SWIZZLE_NONE the box equals the full row). The caller iterates over column groups using the col_idx parameter of async_copy_gather4:

for cg in range(_gather4_num_col_groups[dtype, tile_width, swizzle_mode]()):
    tile.async_copy_gather4(dst, bar, col_idx=Int32(cg * box_width),
                            row0, row1, row2, row3)

Alternatively, use async_copy_gather4_tile to load the full tile_height-row tile in one call (it loops over 4-row chunks and column groups internally).
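
The iteration pattern described above can be modeled on the host in plain Python. This is only a sketch, not the library's implementation: the names box_width and gather4_schedule are hypothetical, and the rule that a swizzled box covers 32, 64, or 128 bytes of a row (so its width in elements is that byte count divided by the element size) is an assumption inferred from the swizzle mode names, not something this page states. The loop order (4-row chunks outside, column groups inside) is likewise assumed.

```python
# Hypothetical host-side model; none of these names exist in the Mojo library.
# Assumption: a swizzled box covers one swizzle span (32, 64 or 128 bytes) of a row.
SWIZZLE_BYTES = {
    "SWIZZLE_NONE": 0,
    "SWIZZLE_32B": 32,
    "SWIZZLE_64B": 64,
    "SWIZZLE_128B": 128,
}


def box_width(elem_bytes, tile_width, swizzle_mode):
    """Elements per row loaded by one gather4 call (assumed rule)."""
    span = SWIZZLE_BYTES[swizzle_mode]
    if span == 0:
        return tile_width  # SWIZZLE_NONE: the box is the full row
    return span // elem_bytes


def gather4_schedule(elem_bytes, tile_height, tile_width, swizzle_mode):
    """List (first_row_slot, col_idx) for every gather4 call a tile needs."""
    if tile_height % 4 != 0:
        raise ValueError("tile_height must be a multiple of 4")
    bw = box_width(elem_bytes, tile_width, swizzle_mode)
    if tile_width % bw != 0:
        raise ValueError("tile_width must be a multiple of the box width")
    num_col_groups = tile_width // bw
    return [
        (row_slot, cg * bw)
        for row_slot in range(0, tile_height, 4)  # 4-row chunks
        for cg in range(num_col_groups)           # column groups
    ]


# 2-byte elements, 8 rows of 128 elements, 128-byte swizzle: 64-element boxes.
calls = gather4_schedule(2, 8, 128, "SWIZZLE_128B")
print(calls)  # [(0, 0), (0, 64), (4, 0), (4, 64)]
```

Under these assumptions, each pair corresponds to one async_copy_gather4 call: the second value is the col_idx to pass, and the first is the offset of the 4-row chunk within the destination tile.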

Parameters:

  • dtype (DType): The element data type.
  • tile_height (Int): Number of rows in the tile. Must be a multiple of 4. Defaults to 4 for backward compatibility.
  • tile_width (Int): Number of elements per row to load. The TMA box width is derived from this value and swizzle_mode; the two are equal only for SWIZZLE_NONE.
  • tile_stride (Int): Row stride in global memory, in elements. Defaults to tile_width. Use a larger value when a row in global memory is wider than the portion to load (for example, when loading only the nope part of a combined nope+rope row).
  • swizzle_mode (TensorMapSwizzle): TMA swizzle mode.
  • l2_promotion (TensorMapL2Promotion): L2 cache promotion hint for TMA loads. Defaults to NONE.
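
To make the role of tile_stride concrete, the sketch below computes which flat element ranges a single gather4 call would read from a row-major buffer. It is a plain-Python model with a hypothetical helper name (gather4_offsets), and the sizes used (rows 576 elements wide in memory, of which the first 512 are loaded) are illustrative assumptions rather than values taken from the library.

```python
# Hypothetical model of row-major addressing; not part of the Mojo API.
def gather4_offsets(rows, col_idx, box_width, tile_stride):
    """Flat element ranges [start, stop) read for four row indices."""
    if len(rows) != 4:
        raise ValueError("gather4 takes exactly four row indices")
    return [
        (r * tile_stride + col_idx, r * tile_stride + col_idx + box_width)
        for r in rows
    ]


# Rows are 576 elements wide in memory, but only the first 512 are loaded.
ranges = gather4_offsets([3, 0, 7, 5], col_idx=0, box_width=512, tile_stride=576)
print(ranges)  # [(1728, 2240), (0, 512), (4032, 4544), (2880, 3392)]
```

Each range starts at a multiple of tile_stride, not tile_width, so the trailing 64 elements of every row are skipped.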

Args:

  • ctx (DeviceContext): CUDA device context for TMA descriptor creation.
  • device_buf (DeviceBuffer): Device buffer containing the 2D row-major tensor data.
  • num_rows (Int): Total number of rows in the tensor.

Returns:

TMATensorTile: A TMATensorTile configured for gather4 with the appropriate box width.

Raises:

An error if TMA descriptor creation fails.

create_tma_tile_gather4[dtype: DType, *, tile_height: Int = 4, tile_width: Int, tile_stride: Int = tile_width, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](ctx: DeviceContext, ptr: UnsafePointer[Scalar[dtype], ptr.origin], num_rows: Int) -> TMATensorTile[dtype, 2, IndexList(tile_height, _gather4_box_width[dtype, tile_width, swizzle_mode](), __list_literal__=Tuple()), IndexList(1, _gather4_box_width[dtype, tile_width, swizzle_mode](), __list_literal__=Tuple())]

Creates a TMATensorTile for gather4 from a raw pointer with automatic box-width computation.

The TMA box width is derived from the swizzle mode. For SWIZZLE_NONE the box width equals tile_width.

Parameters:

  • dtype (DType): The element data type.
  • tile_height (Int): Number of rows in the tile. Must be a multiple of 4. Defaults to 4 for backward compatibility.
  • tile_width (Int): Number of elements per row to load. The TMA box width is derived from this value and swizzle_mode; the two are equal only for SWIZZLE_NONE.
  • tile_stride (Int): Row stride in elements in global memory. Defaults to tile_width. Use a larger value when the row in global memory is wider than the portion to load.
  • swizzle_mode (TensorMapSwizzle): TMA swizzle mode.
  • l2_promotion (TensorMapL2Promotion): L2 cache promotion hint for TMA loads. Defaults to NONE.

Args:

  • ctx (DeviceContext): CUDA device context for TMA descriptor creation.
  • ptr (UnsafePointer): Raw device pointer to the 2D row-major tensor data.
  • num_rows (Int): Total number of rows in the tensor.

Returns:

TMATensorTile: A TMATensorTile configured for gather4 with the appropriate box width.

Raises:

An error if TMA descriptor creation fails.
