Mojo function
create_split_tma
create_split_tma[rank: Int, dtype: DType, //, smem_shape: IndexList[rank], gmem_shape: IndexList[rank], swizzle_mode: TensorMapSwizzle](ctx: DeviceContext, ptr: LegacyUnsafePointer[Scalar[dtype]], runtime_dim0: Int, out res: TMATensorTile[dtype, _split_last_layout[dtype](smem_shape, swizzle_mode, True), _ragged_desc_layout[dtype](smem_shape, swizzle_mode)])
Creates a TMA tensor tile for a global memory tensor whose first dimension is UNKNOWN_VALUE at compile time. The actual size of that dimension is supplied at runtime through runtime_dim0.
This function creates a TMATensorTile that optionally splits the last dimension
of the tensor into multiples of swizzle granularity. This functionality is currently
disabled because it was not found to improve performance.
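The splitting described above can be illustrated with a short Python sketch of the arithmetic alone. It assumes that "swizzle granularity" means the swizzle span in bytes (32, 64, or 128, matching CUDA's tensor-map swizzle modes) divided by the element size. The helper split_last_dim and the mode names in SWIZZLE_BYTES are hypothetical and are not part of the Mojo API.

```python
# Hypothetical sketch: splitting the last dimension of a shared-memory tile
# into chunks of one swizzle span each. Not the library's implementation.

# Assumed swizzle span in bytes per mode (names mirror CUDA's swizzle modes).
SWIZZLE_BYTES = {"SWIZZLE_32B": 32, "SWIZZLE_64B": 64, "SWIZZLE_128B": 128}


def split_last_dim(smem_shape, elem_bytes, swizzle_mode):
    """Return (chunk_shape, num_chunks) for the last dimension of smem_shape.

    chunk_shape is smem_shape with its last dimension clamped to the number
    of elements that fit in one swizzle span; num_chunks is how many such
    chunks cover the original last dimension.
    """
    span_elems = SWIZZLE_BYTES[swizzle_mode] // elem_bytes
    last = smem_shape[-1]
    if last <= span_elems:
        # Already fits in one swizzle span: no split needed.
        return tuple(smem_shape), 1
    if last % span_elems != 0:
        raise ValueError("last dimension must be a multiple of the swizzle span")
    chunk_shape = tuple(smem_shape[:-1]) + (span_elems,)
    return chunk_shape, last // span_elems
```

For example, a bfloat16 tile (2-byte elements) of shape (64, 128) with 128-byte swizzling yields 64-element chunks, so split_last_dim((64, 128), 2, "SWIZZLE_128B") returns ((64, 64), 2).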
Parameters:
- rank (Int): The number of dimensions of the tensor.
- dtype (DType): The data type of the tensor elements.
- smem_shape (IndexList): The shape of the tile in shared memory.
- gmem_shape (IndexList): The shape of the global memory tensor.
- swizzle_mode (TensorMapSwizzle): The swizzling mode for memory access optimization.
Args:
- ctx (DeviceContext): The CUDA device context used to create the TMA descriptor.
- ptr (LegacyUnsafePointer): Pointer to the global memory tensor data.
- runtime_dim0 (Int): The runtime size of the first dimension of the global tensor.
Returns:
TMATensorTile: The resulting TMA tensor tile with split layout.
Raises:
If TMA descriptor creation fails.
create_split_tma[rank: Int, dtype: DType, //, smem_shape: IndexList[rank], gmem_shape: IndexList[rank], swizzle_mode: TensorMapSwizzle](ctx: DeviceContext, ptr: LegacyUnsafePointer[Scalar[dtype]], runtime_dim0: Int, runtime_dim1: Int, out res: TMATensorTile[dtype, _split_last_layout[dtype](smem_shape, swizzle_mode, True), _ragged_desc_layout[dtype](smem_shape, swizzle_mode)])
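Creates a TMA tensor tile for a global memory tensor whose first two dimensions are UNKNOWN_VALUE at compile time. The actual sizes of those dimensions are supplied at runtime through runtime_dim0 and runtime_dim1.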
This function creates a TMATensorTile that optionally splits the last dimension
of the tensor into multiples of swizzle granularity. This functionality is currently
disabled because it was not found to improve performance.
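Both overloads take the sizes of the leading dimensions at runtime instead of from gmem_shape. The following Python sketch shows that substitution under stated assumptions: UNKNOWN_VALUE is modeled as the sentinel -1, and resolve_gmem_shape is a hypothetical helper, not the library's descriptor-creation code.

```python
# Hypothetical sketch: filling in leading dimensions that are UNKNOWN_VALUE
# at compile time with sizes known only at runtime.

UNKNOWN_VALUE = -1  # assumed sentinel for "size known only at runtime"


def resolve_gmem_shape(gmem_shape, *runtime_dims):
    """Replace the leading UNKNOWN_VALUE entries of gmem_shape with runtime_dims.

    One runtime size mirrors the runtime_dim0 overload; two mirror the
    runtime_dim0/runtime_dim1 overload.
    """
    resolved = list(gmem_shape)
    for i, size in enumerate(runtime_dims):
        if resolved[i] != UNKNOWN_VALUE:
            raise ValueError(f"dimension {i} is static and cannot be overridden")
        if size <= 0:
            raise ValueError(f"runtime size for dimension {i} must be positive")
        resolved[i] = size
    if UNKNOWN_VALUE in resolved:
        raise ValueError("some dimensions are still unknown")
    return tuple(resolved)
```

Passing one runtime size corresponds to the runtime_dim0 overload, and passing two corresponds to the runtime_dim0/runtime_dim1 overload.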
Parameters:
- rank (Int): The number of dimensions of the tensor.
- dtype (DType): The data type of the tensor elements.
- smem_shape (IndexList): The shape of the tile in shared memory.
- gmem_shape (IndexList): The shape of the global memory tensor.
- swizzle_mode (TensorMapSwizzle): The swizzling mode for memory access optimization.
Args:
- ctx (DeviceContext): The CUDA device context used to create the TMA descriptor.
- ptr (LegacyUnsafePointer): Pointer to the global memory tensor data.
- runtime_dim0 (Int): The runtime size of the first dimension of the global tensor.
- runtime_dim1 (Int): The runtime size of the second dimension of the global tensor.
Returns:
TMATensorTile: The resulting TMA tensor tile with split layout.
Raises:
If TMA descriptor creation fails.