
Mojo struct

TensorMapArray

struct TensorMapArray[rank: Int, //, dtype: DType, desc_remaining_tile_shape: IndexList[rank], swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, max_descriptor_length: Int = 256]

An array of TMA descriptors for efficient multi-descriptor management.

This struct maintains multiple TMA (Tensor Memory Accelerator) descriptors organized in a power-of-2 indexed array. It lets a kernel select the descriptor that matches a given transfer size at runtime, which is useful when tensor dimensions or batch sizes vary between GPU operations.

The array uses a logarithmic indexing scheme: each descriptor covers a first-dimension size that is a power of 2 (1, 2, 4, 8, 16, ..., up to max_descriptor_length), so the array holds log2(max_descriptor_length) + 1 descriptors (9 for the default of 256). This keeps the array small while still covering every power-of-2 length in that range.
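The indexing scheme can be sketched in a few lines of Python (used here only to show the arithmetic, not the Mojo implementation). The helper names `descriptor_index` and `descriptor_length` are hypothetical, and the sketch assumes that slot i holds the descriptor for a first dimension of 2^i, which the ascending desc_length_list alias suggests but this page does not state outright.

```python
def descriptor_index(length: int) -> int:
    """Return the assumed array slot for a power-of-2 first-dimension length."""
    if length < 1 or length & (length - 1) != 0:
        raise ValueError(f"length must be a power of 2, got {length}")
    # For a power of 2, bit_length() - 1 equals log2(length).
    return length.bit_length() - 1


def descriptor_length(index: int) -> int:
    """Return the first-dimension length assumed to live at slot `index`."""
    if index < 0:
        raise ValueError(f"index must be non-negative, got {index}")
    return 1 << index
```

Under this assumption, a lookup is a single integer log2 rather than a search through the array.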

Constraints:

  • The rank must be 1 or 2 for descriptor shape construction.
  • When swizzling is enabled, tile dimensions must comply with swizzle mode byte limits.
  • max_descriptor_length must be a power of 2 no greater than 256; smaller values reduce the number of descriptors stored.

Parameters

  • rank (Int): The rank (number of dimensions) of the tensors that will be accessed. Currently supports rank 1 or 2.
  • dtype (DType): The data type of the tensor elements that will be transferred.
  • desc_remaining_tile_shape (IndexList): All dimensions of the descriptor shape except the first.
  • swizzle_mode (TensorMapSwizzle): The swizzling mode to use for memory access optimization. Swizzling can improve memory access patterns for specific hardware configurations. Defaults to SWIZZLE_NONE.
  • max_descriptor_length (Int): The maximum first dimension size supported by the descriptor array. The array will contain descriptors for all powers of 2 up to this value. Defaults to 256.

Fields

  • descriptor_array (InlineArray[TMADescriptor, TensorMapArray[dtype, desc_remaining_tile_shape, swizzle_mode, max_descriptor_length].arr_size]): The array of TMA descriptors.

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, UnknownDestructibility

Aliases

__copyinit__is_trivial

alias __copyinit__is_trivial = False

__del__is_trivial

alias __del__is_trivial = False

arr_size

alias arr_size = (Int.__init__[Float32](log2(SIMD[DType.float32, 1](max_descriptor_length))) + 1)

How many descriptors are in the array.

desc_length_list

alias desc_length_list = TensorMapArray._desc_length_list[rank, dtype, desc_remaining_tile_shape, swizzle_mode, max_descriptor_length]()

The list of descriptor lengths in ascending order.

desc_length_list_reverse

alias desc_length_list_reverse = TensorMapArray[dtype, desc_remaining_tile_shape, swizzle_mode, max_descriptor_length].desc_length_list.reverse[TensorMapArray[dtype, desc_remaining_tile_shape, swizzle_mode, max_descriptor_length].arr_size, DType.int64]()

The list of descriptor lengths in descending order.
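As a rough illustration of the three aliases above, the following Python sketch reproduces the arithmetic with plain integers. It mirrors the arr_size formula shown on this page; the function names are reused for readability only and are not the Mojo aliases themselves. For positive integers, `bit_length()` equals floor(log2(n)) + 1, so it stands in for the floating-point log2 call.

```python
def arr_size(max_descriptor_length: int = 256) -> int:
    """Number of descriptors: floor(log2(max_descriptor_length)) + 1."""
    if max_descriptor_length < 1:
        raise ValueError("max_descriptor_length must be positive")
    return max_descriptor_length.bit_length()


def desc_length_list(max_descriptor_length: int = 256) -> list[int]:
    """Descriptor lengths in ascending order: 1, 2, 4, ..."""
    return [1 << i for i in range(arr_size(max_descriptor_length))]


def desc_length_list_reverse(max_descriptor_length: int = 256) -> list[int]:
    """Descriptor lengths in descending order."""
    return desc_length_list(max_descriptor_length)[::-1]
```

With the default max_descriptor_length of 256, this gives 9 descriptors covering the lengths 1 through 256.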

device_type

alias device_type = TensorMapArray[dtype, desc_remaining_tile_shape, swizzle_mode, max_descriptor_length]

The type used on the device, which is the TensorMapArray type itself.

Methods

__init__

__init__(out self, ctx: DeviceContext, global_tensor: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])

Initializes a TensorMapArray with descriptors for all power-of-2 lengths.

This constructor creates a complete set of TMA descriptors, one for each power of 2 from 1 up to max_descriptor_length. Each descriptor is configured to handle a different first dimension size (1, 2, 4, 8, ..., max_descriptor_length) while maintaining the same remaining tile shape specified by desc_remaining_tile_shape.

Constraints:

  • max_descriptor_length must be a power of two.
  • max_descriptor_length must be less than or equal to 256.

Args:

  • ctx (DeviceContext): The device context used to create the TMA descriptors.
  • global_tensor (LayoutTensor): The source tensor in global memory that will be accessed using these descriptors. This defines the global memory layout and data type.
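To make the constructor's behavior concrete, here is a minimal Python sketch that lists the tile shapes the descriptors would cover. It is not the Mojo implementation and creates no real TMA descriptors. The helper `descriptor_shapes` is hypothetical, and the sketch assumes each descriptor's shape is the first-dimension length followed by desc_remaining_tile_shape.

```python
def descriptor_shapes(
    desc_remaining_tile_shape: tuple[int, ...],
    max_descriptor_length: int = 256,
) -> list[tuple[int, ...]]:
    """List the assumed tile shape of each descriptor, smallest first."""
    n = max_descriptor_length
    # Constraints documented for __init__: power of two, at most 256.
    if n < 1 or n & (n - 1) != 0:
        raise ValueError("max_descriptor_length must be a power of two")
    if n > 256:
        raise ValueError("max_descriptor_length must be <= 256")

    shapes = []
    length = 1
    while length <= n:
        # Only the first dimension changes; the remaining tile shape is shared.
        shapes.append((length, *desc_remaining_tile_shape))
        length *= 2
    return shapes
```

For example, `descriptor_shapes((64,), 8)` yields four shapes whose first dimensions are 1, 2, 4, and 8, each followed by the shared remaining dimension 64.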

get_type_name

static get_type_name() -> String

Returns a string representation of the TensorMapArray type.

Returns:

String: A string containing the type name with all template parameters.

calculate_dim_repeat

static calculate_dim_repeat(sequence_length: Int) -> Int

Returns the number of times the descriptor length fits into the sequence length.

Args:

  • sequence_length (Int): The length of the sequence to be transferred.

Returns:

Int: The number of times the 1-dim descriptor fits into the sequence length.

get_device_type_name

static get_device_type_name() -> String

Returns the device type name for this descriptor array.

Returns:

String: A string containing the type name with all template parameters.

store_ragged_tile

store_ragged_tile(self, rows_to_copy: Int, start_coord: IndexList[(rank + 1)], src: LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED, mut=mut, origin=origin])

Stores a ragged tile from shared memory to global memory using multiple TMA descriptors.

This method efficiently handles non-power-of-2 row counts by decomposing the transfer into multiple operations using the largest possible descriptors. It uses a greedy algorithm to select descriptors in descending order (largest first), minimizing the number of individual TMA operations required.

For example, transferring 13 rows would use descriptors for 8 + 4 + 1 rows.

Args:

  • rows_to_copy (Int): The total number of rows to transfer from shared memory to global memory.
  • start_coord (IndexList): The starting coordinate in global memory where the transfer should begin. Must have rank+1 dimensions (includes the batch/first dimension).
  • src (LegacyUnsafePointer): The source pointer in shared memory containing the data to be transferred. Must be in the SHARED address space.
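The greedy decomposition described for store_ragged_tile can be sketched as follows, in Python, to show only the selection logic. The helper `ragged_tile_chunks` is hypothetical: it returns the (row offset, descriptor length) pairs a transfer would be split into and performs no memory operations. Reusing the largest descriptor when rows_to_copy exceeds max_descriptor_length is an assumption of this sketch, not something this page confirms.

```python
def ragged_tile_chunks(
    rows_to_copy: int, max_descriptor_length: int = 256
) -> list[tuple[int, int]]:
    """Split a row count into power-of-2 chunks, largest first."""
    if rows_to_copy < 0:
        raise ValueError("rows_to_copy must be non-negative")

    chunks = []
    offset = 0
    remaining = rows_to_copy
    length = max_descriptor_length
    while length >= 1 and remaining > 0:
        # Assumption: a descriptor may be reused while it still fits.
        while remaining >= length:
            chunks.append((offset, length))
            offset += length
            remaining -= length
        length //= 2  # move to the next smaller descriptor
    return chunks
```

For 13 rows this produces chunks of 8, 4, and 1 rows at offsets 0, 8, and 12, matching the example above. Because every power of 2 down to 1 is available, the loop always ends with nothing left to copy.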
