Mojo function
create_tma_tile_gather4
create_tma_tile_gather4[dtype: DType, row_width: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE](ctx: DeviceContext, device_buf: DeviceBuffer[dtype], num_rows: Int) -> TMATensorTile[dtype, 2, IndexList(VariadicList(4, row_width), Tuple()), IndexList(VariadicList(1, row_width), Tuple())]
Creates a TMATensorTile configured for gather4 operations.
Gather4 loads 4 non-contiguous rows from a 2D tensor using a single TMA instruction (SM100/Blackwell). The descriptor's box shape (desc_shape) is (1, row_width) because the hardware requires one row per box: gather4 internally issues 4 single-row copies at 4 different row indices. The tile shape is (4, row_width) because it describes the total shared memory footprint of the 4 gathered rows.
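The gather semantics described above can be modeled on the host in plain Python. This is a reference model, not the Mojo API: `gather4_reference` is a hypothetical helper, and nested lists stand in for the row-major global tensor and the shared memory tile. It shows why the descriptor box covers one row while the destination tile covers four.

```python
from typing import List, Sequence

Row = List[int]


def gather4_reference(tensor: Sequence[Sequence[int]], row_indices: Sequence[int]) -> List[Row]:
    """Hypothetical host-side model of gather4 (not the Mojo/TMA API).

    Each index triggers one copy of a (1, row_width) box; the four copied
    rows are stacked, in index order, into a (4, row_width) tile.
    """
    if len(row_indices) != 4:
        raise ValueError("gather4 takes exactly 4 row indices")
    tile: List[Row] = []
    for r in row_indices:
        # This model simply rejects out-of-range rows; it does not model
        # how the hardware handles out-of-bounds coordinates.
        if not 0 <= r < len(tensor):
            raise IndexError(f"row index {r} is outside [0, {len(tensor)})")
        tile.append(list(tensor[r]))  # one single-row "box" copy
    return tile


# An 8-row tensor with row_width = 3; element (r, c) holds r * 10 + c.
tensor = [[r * 10 + c for c in range(3)] for r in range(8)]
tile = gather4_reference(tensor, [6, 1, 4, 3])
print(tile)                     # [[60, 61, 62], [10, 11, 12], [40, 41, 42], [30, 31, 32]]
print(len(tile), len(tile[0]))  # 4 3, matching tile_shape = (4, row_width)
```

The rows need not be adjacent in the source tensor; only the destination tile is contiguous.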
Parameters:
- dtype (DType): The element data type of the tensor.
- row_width (Int): Number of elements per row (innermost dimension).
- swizzle_mode (TensorMapSwizzle): TMA swizzle mode for the shared memory access pattern. Defaults to SWIZZLE_NONE. Use SWIZZLE_64B or SWIZZLE_128B for bandwidth-optimized access (e.g., BF16 RoPE data in FlashMLA).
Args:
- ctx (DeviceContext): The CUDA device context used to create the TMA descriptor.
- device_buf (DeviceBuffer): Device buffer containing the 2D tensor data in row-major layout.
- num_rows (Int): Total number of rows in the tensor (outermost dimension).
Returns:
TMATensorTile: A TMATensorTile with tile_shape=(4, row_width) and desc_shape=(1, row_width), configured for use with async_copy_gather4.
Raises:
If TMA descriptor creation fails.
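Because tile_shape is (4, row_width) while desc_shape is (1, row_width), the shared memory destination of one gather4 spans four descriptor boxes. The arithmetic below is a minimal sketch: the dtype names and byte sizes in `ELEMENT_BYTES` are assumptions chosen for illustration, and `gather4_footprint` is a hypothetical helper, not part of the Mojo API. It ignores any alignment or padding a real shared memory allocation may need.

```python
from typing import Dict, Tuple

# Assumed element sizes in bytes, for illustration only.
ELEMENT_BYTES: Dict[str, int] = {"float32": 4, "bfloat16": 2, "float16": 2, "float8_e4m3fn": 1}


def gather4_footprint(dtype: str, row_width: int) -> Tuple[int, int]:
    """Return (bytes per descriptor box, bytes per gather4 tile)."""
    box_bytes = 1 * row_width * ELEMENT_BYTES[dtype]  # desc_shape = (1, row_width)
    tile_bytes = 4 * box_bytes                        # tile_shape = (4, row_width)
    return box_bytes, tile_bytes


print(gather4_footprint("bfloat16", 64))  # (128, 512)
print(gather4_footprint("float32", 128))  # (512, 2048)
```

The row widths here are arbitrary; the point is that the tile is always exactly four boxes.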
create_tma_tile_gather4[dtype: DType, row_width: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE](ctx: DeviceContext, ptr: UnsafePointer[Scalar[dtype], ptr.origin], num_rows: Int) -> TMATensorTile[dtype, 2, IndexList(VariadicList(4, row_width), Tuple()), IndexList(VariadicList(1, row_width), Tuple())]
Creates a TMATensorTile configured for gather4 operations from a raw pointer.
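This overload accepts a raw device pointer instead of a DeviceBuffer, matching the pattern used by create_split_tma and the KV cache TMA tile factories. Internally it wraps the pointer in a non-owning DeviceBuffer, so the caller remains responsible for keeping the underlying allocation alive while the descriptor is in use.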
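The raw-pointer overload differs from the DeviceBuffer version only in how the memory is wrapped. The sketch below models that delegation pattern in Python; `BufferView`, `Gather4Descriptor`, and both `make_gather4_*` functions are hypothetical stand-ins, not Mojo types, and the model says nothing about how the real TMA descriptor is encoded.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class BufferView:
    """Hypothetical stand-in for DeviceBuffer: base address, element count, ownership."""
    address: int
    num_elements: int
    owning: bool  # True: the buffer frees its memory; False: it only borrows it


@dataclass(frozen=True)
class Gather4Descriptor:
    """Hypothetical summary of what a gather4 descriptor records."""
    address: int
    num_rows: int
    row_width: int
    desc_shape: Tuple[int, int]
    tile_shape: Tuple[int, int]


def make_gather4_from_buffer(buf: BufferView, num_rows: int, row_width: int) -> Gather4Descriptor:
    # In this model the descriptor needs only the base address and the 2D extent.
    return Gather4Descriptor(buf.address, num_rows, row_width, (1, row_width), (4, row_width))


def make_gather4_from_pointer(address: int, num_rows: int, row_width: int) -> Gather4Descriptor:
    # Wrap the raw address in a non-owning view, then reuse the buffer-based path.
    # The caller must keep the allocation alive while the descriptor is in use.
    view = BufferView(address, num_rows * row_width, owning=False)
    return make_gather4_from_buffer(view, num_rows, row_width)


owned = BufferView(0x1000, 8 * 64, owning=True)
print(make_gather4_from_pointer(0x1000, 8, 64) == make_gather4_from_buffer(owned, 8, 64))  # True
```

Funneling both entry points through one factory keeps the descriptor setup in a single place; ownership only affects who frees the memory.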
Parameters:
- dtype (DType): The element data type of the tensor.
- row_width (Int): Number of elements per row (innermost dimension).
- swizzle_mode (TensorMapSwizzle): TMA swizzle mode for the shared memory access pattern. Defaults to SWIZZLE_NONE. Use SWIZZLE_64B or SWIZZLE_128B for bandwidth-optimized access (e.g., BF16 RoPE data in FlashMLA).
Args:
- ctx (DeviceContext): The CUDA device context used to create the TMA descriptor.
- ptr (UnsafePointer): Raw device pointer to the 2D tensor data in row-major layout.
- num_rows (Int): Total number of rows in the tensor (outermost dimension).
Returns:
TMATensorTile: A TMATensorTile with tile_shape=(4, row_width) and desc_shape=(1, row_width), configured for use with async_copy_gather4.
Raises:
If TMA descriptor creation fails.