Mojo function

create_tma_tile_gather4

create_tma_tile_gather4[dtype: DType, row_width: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE](ctx: DeviceContext, device_buf: DeviceBuffer[dtype], num_rows: Int) -> TMATensorTile[dtype, 2, IndexList(VariadicList(4, row_width), Tuple()), IndexList(VariadicList(1, row_width), Tuple())]

Creates a TMATensorTile configured for gather4 operations.

Gather4 loads 4 non-contiguous rows from a 2D tensor with a single TMA instruction (SM100/Blackwell). The descriptor's box shape (desc_shape) is (1, row_width) because the hardware requires each box to cover exactly one row: gather4 internally issues 4 single-row copies, one at each of the 4 row indices. The tile shape is (4, row_width), which reflects the total shared memory footprint of the 4 gathered rows.
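As a rough illustration of the semantics described above, the following host-side Python model gathers 4 arbitrary rows of a row-major 2D tensor into one (4, row_width) tile. It is only a sketch: `gather4_reference` is a hypothetical helper written for this example, it does not call the Mojo API or touch a GPU, and rejecting out-of-range indices is a choice of the model rather than a statement about hardware behavior.

```python
def gather4_reference(tensor, row_indices):
    """Model of gather4: copy 4 rows of a 2D row-major tensor into one tile.

    Hypothetical reference helper, not part of the Mojo API.
    """
    if len(row_indices) != 4:
        raise ValueError("gather4 takes exactly 4 row indices")
    num_rows = len(tensor)
    tile = []
    for r in row_indices:
        # Model choice: reject indices outside [0, num_rows).
        if not 0 <= r < num_rows:
            raise IndexError(f"row index {r} outside [0, {num_rows})")
        # Each index corresponds to one (1, row_width) box copy.
        tile.append(list(tensor[r]))
    return tile  # shape (4, row_width)


row_width = 8
num_rows = 16
# Element value encodes its position: row * 100 + column.
tensor = [[r * 100 + c for c in range(row_width)] for r in range(num_rows)]

tile = gather4_reference(tensor, [9, 2, 14, 5])
print(len(tile), len(tile[0]))
```

The rows land in the tile in the order the indices are given, so the 4 source rows do not need to be adjacent or sorted in this model.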

Parameters:

  • dtype (DType): The element data type of the tensor.
  • row_width (Int): Number of elements per row (innermost dimension).
  • swizzle_mode (TensorMapSwizzle): TMA swizzle mode for the shared memory access pattern. Defaults to SWIZZLE_NONE. Use SWIZZLE_64B or SWIZZLE_128B for bandwidth-optimized access (e.g., BF16 RoPE data in FlashMLA).

Args:

  • ctx (DeviceContext): The CUDA device context used to create the TMA descriptor.
  • device_buf (DeviceBuffer): Device buffer containing the 2D tensor data in row-major layout.
  • num_rows (Int): Total number of rows in the tensor (outermost dimension).

Returns:

TMATensorTile: A TMATensorTile with tile_shape=(4, row_width) and desc_shape=(1, row_width), configured for use with async_copy_gather4.
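To make the two shapes concrete, this small Python sketch computes the byte sizes implied by tile_shape=(4, row_width) and desc_shape=(1, row_width). The helper name `gather4_footprint` and the `ELEMENT_BYTES` table are assumptions made for this example; only the two shapes come from the description above.

```python
# Element sizes in bytes for a few common dtypes (illustrative subset).
ELEMENT_BYTES = {"bfloat16": 2, "float16": 2, "float32": 4}


def gather4_footprint(dtype, row_width):
    """Return the shapes and byte sizes for a gather4 tile.

    Hypothetical helper for illustration, not part of the Mojo API.
    """
    elem = ELEMENT_BYTES[dtype]
    tile_shape = (4, row_width)  # all 4 gathered rows in shared memory
    desc_shape = (1, row_width)  # one-row box described by the descriptor
    return {
        "tile_shape": tile_shape,
        "desc_shape": desc_shape,
        "tile_bytes": tile_shape[0] * tile_shape[1] * elem,
        "box_bytes": desc_shape[0] * desc_shape[1] * elem,
    }


info = gather4_footprint("bfloat16", 64)
print(info["tile_bytes"], info["box_bytes"])
```

For BF16 with row_width=64, each one-row box is 128 bytes and the full tile occupies 512 bytes of shared memory, which is the amount a kernel would need to reserve per gather4 destination under these assumptions.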

Raises:

If TMA descriptor creation fails.

create_tma_tile_gather4[dtype: DType, row_width: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE](ctx: DeviceContext, ptr: UnsafePointer[Scalar[dtype], ptr.origin], num_rows: Int) -> TMATensorTile[dtype, 2, IndexList(VariadicList(4, row_width), Tuple()), IndexList(VariadicList(1, row_width), Tuple())]

Creates a TMATensorTile configured for gather4 operations from a raw pointer.

This overload accepts a raw device pointer instead of a DeviceBuffer, matching the pattern used by create_split_tma and the KV cache TMA tile factories. A non-owning DeviceBuffer is constructed internally.

Parameters:

  • dtype (DType): The element data type of the tensor.
  • row_width (Int): Number of elements per row (innermost dimension).
  • swizzle_mode (TensorMapSwizzle): TMA swizzle mode for the shared memory access pattern. Defaults to SWIZZLE_NONE. Use SWIZZLE_64B or SWIZZLE_128B for bandwidth-optimized access (e.g., BF16 RoPE data in FlashMLA).

Args:

  • ctx (DeviceContext): The CUDA device context used to create the TMA descriptor.
  • ptr (UnsafePointer): Raw device pointer to the 2D tensor data in row-major layout.
  • num_rows (Int): Total number of rows in the tensor (outermost dimension).

Returns:

TMATensorTile: A TMATensorTile with tile_shape=(4, row_width) and desc_shape=(1, row_width), configured for use with async_copy_gather4.

Raises:

If TMA descriptor creation fails.
