Mojo struct
TensorMapArray
struct TensorMapArray[rank: Int, //, dtype: DType, desc_remaining_tile_shape: IndexList[rank], swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, max_descriptor_length: Int = 256]
An array of TMA descriptors for efficient multi-descriptor management.
This struct maintains multiple TMA (Tensor Memory Accelerator) descriptors in an array indexed by powers of 2. It lets a kernel select a descriptor configuration at runtime, which is useful when tensor dimensions or batch sizes vary between GPU operations.
The array uses a logarithmic indexing scheme where descriptors are stored at positions corresponding to powers of 2 (1, 2, 4, 8, 16, ..., up to max_descriptor_length). This allows for efficient lookup and memory usage while supporting a wide range of descriptor configurations.
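The indexing scheme can be modeled in a few lines of plain Python. This is an illustrative sketch, not the Mojo implementation: the helper names `descriptor_lengths` and `descriptor_index` are hypothetical, and the sketch assumes the descriptor for a length of 2^k is stored at index k, which the description implies but does not state.

```python
def descriptor_lengths(max_descriptor_length: int = 256) -> list[int]:
    """Return the power-of-2 first-dimension sizes in ascending order."""
    n = max_descriptor_length
    if n < 1 or n & (n - 1):
        raise ValueError("max_descriptor_length must be a power of two")
    lengths = []
    length = 1
    while length <= n:
        lengths.append(length)
        length *= 2
    return lengths


def descriptor_index(length: int) -> int:
    """Assumed mapping: the descriptor for 2**k rows lives at index k."""
    if length < 1 or length & (length - 1):
        raise ValueError("length must be a power of two")
    return length.bit_length() - 1


print(descriptor_lengths(16))  # [1, 2, 4, 8, 16]
print(descriptor_index(8))     # 3
```

With the default `max_descriptor_length` of 256, this model yields nine lengths (1 through 256).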
Constraints:
- The rank must be 1 or 2 for descriptor shape construction.
- When swizzling is enabled, tile dimensions must comply with swizzle mode byte limits.
- max_descriptor_length must be a power of 2; the constructor additionally requires it to be at most 256.
Parameters
- rank (Int): The rank (number of dimensions) of the tensors that will be accessed. Currently supports rank 1 or 2.
- dtype (DType): The data type of the tensor elements that will be transferred.
- desc_remaining_tile_shape (IndexList): All dimensions of the descriptor shape except the first.
- swizzle_mode (TensorMapSwizzle): The swizzling mode to use for memory access optimization. Swizzling can improve memory access patterns for specific hardware configurations. Defaults to SWIZZLE_NONE.
- max_descriptor_length (Int): The maximum first dimension size supported by the descriptor array. The array will contain descriptors for all powers of 2 up to this value. Defaults to 256.
Fields
- descriptor_array (InlineArray[TMADescriptor, TensorMapArray[dtype, desc_remaining_tile_shape, swizzle_mode, max_descriptor_length].arr_size]): The array of TMA descriptors.
Implemented traits
AnyType,
Copyable,
DevicePassable,
ImplicitlyCopyable,
UnknownDestructibility
Aliases
__copyinit__is_trivial
alias __copyinit__is_trivial = False
__del__is_trivial
alias __del__is_trivial = False
arr_size
alias arr_size = (Int.__init__[Float32](log2(SIMD[DType.float32, 1](max_descriptor_length))) + 1)
How many descriptors are in the array.
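As a quick check of the formula, the same arithmetic can be reproduced in plain Python. This is only a sketch of the calculation; `arr_size` here is a hypothetical Python function named after the alias, and it assumes the input is a power of two so that the logarithm is an exact integer.

```python
import math


def arr_size(max_descriptor_length: int = 256) -> int:
    # Mirrors the alias: truncate log2 to an integer, then add one
    # so that both endpoints (1 and max_descriptor_length) are counted.
    return int(math.log2(max_descriptor_length)) + 1


for n in (1, 16, 256):
    print(n, arr_size(n))
# 1 1
# 16 5
# 256 9
```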
desc_length_list
alias desc_length_list = TensorMapArray._desc_length_list[rank, dtype, desc_remaining_tile_shape, swizzle_mode, max_descriptor_length]()
The list of descriptor lengths in ascending order.
desc_length_list_reverse
alias desc_length_list_reverse = TensorMapArray[dtype, desc_remaining_tile_shape, swizzle_mode, max_descriptor_length].desc_length_list.reverse[TensorMapArray[dtype, desc_remaining_tile_shape, swizzle_mode, max_descriptor_length].arr_size, DType.int64]()
The list of descriptor lengths in descending order.
device_type
alias device_type = TensorMapArray[dtype, desc_remaining_tile_shape, swizzle_mode, max_descriptor_length]
The device-side type for this struct, which is TensorMapArray itself.
Methods
__init__
__init__(out self, ctx: DeviceContext, global_tensor: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])
Initializes a TensorMapDescriptorArray with descriptors for all power-of-2 lengths.
This constructor creates a complete set of TMA descriptors, one for each power of 2 from 1 up to max_descriptor_length. Each descriptor is configured to handle a different first dimension size (1, 2, 4, 8, ..., max_descriptor_length) while maintaining the same remaining tile shape specified by desc_remaining_tile_shape.
Constraints:
- max_descriptor_length must be a power of two.
- max_descriptor_length must be less than or equal to 256.
Args:
- ctx (DeviceContext): The device context used to create the TMA descriptors.
- global_tensor (LayoutTensor): The source tensor in global memory that will be accessed using these descriptors. This defines the global memory layout and data type.
get_type_name
static get_type_name() -> String
Returns a string representation of the TensorMapDescriptorArray type.
Returns:
String: A string containing the type name with all template parameters.
calculate_dim_repeat
static calculate_dim_repeat(sequence_length: Int) -> Int
Returns the number of times the descriptor length fits into the sequence length.
Args:
- sequence_length (Int): The length of the sequence to be transferred.
Returns:
Int: The number of times the 1 dim descriptor fits into the sequence length.
get_device_type_name
static get_device_type_name() -> String
Returns the device type name for this descriptor array.
Returns:
String: A string containing the type name with all template parameters.
store_ragged_tile
store_ragged_tile(self, rows_to_copy: Int, start_coord: IndexList[(rank + 1)], src: LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED, mut=mut, origin=origin])
Stores a ragged tile from shared memory to global memory using multiple TMA descriptors.
This method efficiently handles non-power-of-2 row counts by decomposing the transfer into multiple operations using the largest possible descriptors. It uses a greedy algorithm to select descriptors in descending order (largest first), minimizing the number of individual TMA operations required.
For example, transferring 13 rows would use descriptors for 8 + 4 + 1 rows.
Args:
- rows_to_copy (Int): The total number of rows to transfer from shared memory to global memory.
- start_coord (IndexList): The starting coordinate in global memory where the transfer should begin. Must have rank+1 dimensions (includes the batch/first dimension).
- src (LegacyUnsafePointer): The source pointer in shared memory containing the data to be transferred. Must be in the SHARED address space.
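The greedy decomposition described for store_ragged_tile can be sketched in plain Python. This models only the choice of descriptor sizes and row offsets, not the TMA store itself. The function name `greedy_descriptor_plan` is hypothetical, and the handling of row counts larger than max_descriptor_length (repeating the largest descriptor) is an assumption, since the source does not say how that case is treated.

```python
def greedy_descriptor_plan(
    rows_to_copy: int, max_descriptor_length: int = 256
) -> list[tuple[int, int]]:
    """Return (row_offset, descriptor_length) pairs, largest descriptors first."""
    if rows_to_copy < 0:
        raise ValueError("rows_to_copy must be non-negative")
    plan = []
    offset = 0
    length = max_descriptor_length
    while length >= 1:
        # Assumption: if more rows remain than the current descriptor covers,
        # reuse the same descriptor until it no longer fits.
        while rows_to_copy - offset >= length:
            plan.append((offset, length))
            offset += length
        length //= 2  # step down to the next smaller power of two
    return plan


print(greedy_descriptor_plan(13))
# [(0, 8), (8, 4), (12, 1)]
```

For 13 rows the plan matches the 8 + 4 + 1 split in the description. Each selected length corresponds to a set bit in the binary form of the row count, so a count below 2 * max_descriptor_length needs at most one transfer per descriptor.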