Mojo function
create_tma_tile_gather4
create_tma_tile_gather4[dtype: DType, *, tile_height: Int = 4, tile_width: Int, tile_stride: Int = tile_width, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](ctx: DeviceContext, device_buf: DeviceBuffer[dtype], num_rows: Int) -> TMATensorTile[dtype, 2, IndexList(tile_height, _gather4_box_width[dtype, tile_width, swizzle_mode](), __list_literal__=Tuple()), IndexList(1, _gather4_box_width[dtype, tile_width, swizzle_mode](), __list_literal__=Tuple())]
Creates a TMATensorTile for gather4 with automatic box-width computation.
The global tensor has tile_width elements per row. The TMA box
width is derived from the swizzle mode so that each gather4 call loads one
swizzle-group-sized column chunk (for SWIZZLE_NONE the box equals the full
row). The caller iterates over column groups using the col_idx
parameter of async_copy_gather4:

    for cg in range(_gather4_num_col_groups[dtype, tile_width, swizzle_mode]()):
        tile.async_copy_gather4(dst, bar, col_idx=Int32(cg * box_width),
                                row0, row1, row2, row3)

Alternatively, use async_copy_gather4_tile to load the full
tile_height-row tile in one call (it loops over 4-row chunks and
column groups internally).
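The loop structure described above can be modeled on the host. The following is a minimal plain-Python sketch, not the library's implementation. It assumes the box width for a swizzled mode is the swizzle group size in bytes (32, 64, or 128) divided by the element size, and that tile_width is a whole number of boxes. The helper names box_width, num_col_groups, and gather4_schedule are hypothetical stand-ins for illustration. The private Mojo helpers _gather4_box_width and _gather4_num_col_groups may compute these values differently.

```python
# Plain-Python model of the gather4 loop structure.
# Every name in this block is hypothetical; none of it is Mojo library code.

# Assumed bytes covered by one swizzle group. None means no swizzling.
SWIZZLE_GROUP_BYTES = {
    "SWIZZLE_NONE": None,
    "SWIZZLE_32B": 32,
    "SWIZZLE_64B": 64,
    "SWIZZLE_128B": 128,
}


def box_width(elem_bytes, tile_width, swizzle_mode):
    """Elements loaded per row by one gather4 call (assumed rule)."""
    group_bytes = SWIZZLE_GROUP_BYTES[swizzle_mode]
    if group_bytes is None:
        # Per the text above: for SWIZZLE_NONE the box equals the full row.
        return tile_width
    return group_bytes // elem_bytes


def num_col_groups(elem_bytes, tile_width, swizzle_mode):
    """Number of column chunks needed to cover tile_width elements."""
    width = box_width(elem_bytes, tile_width, swizzle_mode)
    if tile_width % width != 0:
        raise ValueError("tile_width must be a multiple of the box width")
    return tile_width // width


def gather4_schedule(elem_bytes, tile_height, tile_width, swizzle_mode):
    """List (first_row_slot, col_idx) for each gather4 call in a full tile."""
    if tile_height % 4 != 0:
        raise ValueError("tile_height must be a multiple of 4")
    width = box_width(elem_bytes, tile_width, swizzle_mode)
    groups = num_col_groups(elem_bytes, tile_width, swizzle_mode)
    return [
        (chunk * 4, cg * width)
        for chunk in range(tile_height // 4)
        for cg in range(groups)
    ]


# 2-byte elements, 8 rows, 128 elements per row, 128-byte swizzle.
print(gather4_schedule(2, 8, 128, "SWIZZLE_128B"))
```

Under these assumptions the call prints `[(0, 0), (0, 64), (4, 0), (4, 64)]`: two 4-row chunks, each split into two 64-element column groups. With SWIZZLE_NONE the same tile needs one call per 4-row chunk, since the box spans the full row.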
Parameters:
- dtype (DType): The element data type.
- tile_height (Int): Number of rows in the tile. Must be a multiple of 4. Defaults to 4 for backward compatibility.
- tile_width (Int): Number of elements per row to load. The TMA box width is derived from this value and swizzle_mode.
- tile_stride (Int): Row stride in elements in global memory. Defaults to tile_width. Use a larger value when the row in global memory is wider than the portion to load (e.g. loading only nope from a nope+rope row).
- swizzle_mode (TensorMapSwizzle): TMA swizzle mode. Defaults to SWIZZLE_NONE.
- l2_promotion (TensorMapL2Promotion): L2 cache promotion hint for TMA loads. Defaults to NONE.
Args:
- ctx (DeviceContext): CUDA device context for TMA descriptor creation.
- device_buf (DeviceBuffer): Device buffer containing the 2D row-major tensor data.
- num_rows (Int): Total number of rows in the tensor.
Returns:
TMATensorTile: A TMATensorTile configured for gather4 with the appropriate box width.
Raises:
If TMA descriptor creation fails.
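The effect of tile_stride can be illustrated with a small host-side model. This is a plain-Python sketch of row-major addressing under the assumption that element (row, col) lives at offset row * tile_stride + col. The function gather_rows and the sizes used (4 nope elements followed by 2 rope elements per row) are hypothetical and chosen only to keep the output readable.

```python
# Plain-Python model of strided row gathering.
# gather_rows is a hypothetical helper, not part of the Mojo library.

def gather_rows(flat, rows, tile_width, tile_stride):
    """Return the first tile_width elements of each requested row.

    flat is a row-major buffer whose rows are tile_stride elements apart.
    """
    if tile_stride < tile_width:
        raise ValueError("tile_stride must be at least tile_width")
    return [
        flat[r * tile_stride : r * tile_stride + tile_width]
        for r in rows
    ]


NOPE, ROPE = 4, 2          # illustrative sizes only
STRIDE = NOPE + ROPE       # each global row holds nope followed by rope
flat = list(range(5 * STRIDE))  # 5 rows, values equal their flat offsets

# Gather four arbitrary rows, loading only the nope portion of each.
picked = gather_rows(flat, [3, 0, 4, 1], tile_width=NOPE, tile_stride=STRIDE)
print(picked)
```

Here tile_width is 4 and tile_stride is 6, so each gathered row skips the trailing 2 rope elements. The call prints `[[18, 19, 20, 21], [0, 1, 2, 3], [24, 25, 26, 27], [6, 7, 8, 9]]`.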
create_tma_tile_gather4[dtype: DType, *, tile_height: Int = 4, tile_width: Int, tile_stride: Int = tile_width, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](ctx: DeviceContext, ptr: UnsafePointer[Scalar[dtype], ptr.origin], num_rows: Int) -> TMATensorTile[dtype, 2, IndexList(tile_height, _gather4_box_width[dtype, tile_width, swizzle_mode](), __list_literal__=Tuple()), IndexList(1, _gather4_box_width[dtype, tile_width, swizzle_mode](), __list_literal__=Tuple())]
Creates a TMATensorTile for gather4 from a raw pointer with automatic box-width computation.
The TMA box width is derived from the swizzle mode. For SWIZZLE_NONE the
box width equals tile_width.
Parameters:
- dtype (DType): The element data type.
- tile_height (Int): Number of rows in the tile. Must be a multiple of 4. Defaults to 4 for backward compatibility.
- tile_width (Int): Number of elements per row to load. The TMA box width is derived from this value and swizzle_mode.
- tile_stride (Int): Row stride in elements in global memory. Defaults to tile_width. Use a larger value when the row in global memory is wider than the portion to load.
- swizzle_mode (TensorMapSwizzle): TMA swizzle mode. Defaults to SWIZZLE_NONE.
- l2_promotion (TensorMapL2Promotion): L2 cache promotion hint for TMA loads. Defaults to NONE.
Args:
- ctx (DeviceContext): CUDA device context for TMA descriptor creation.
- ptr (UnsafePointer): Raw device pointer to the 2D row-major tensor data.
- num_rows (Int): Total number of rows in the tensor.
Returns:
TMATensorTile: A TMATensorTile configured for gather4 with the appropriate box width.
Raises:
If TMA descriptor creation fails.