Mojo struct
RaggedTensorMap
struct RaggedTensorMap[descriptor_rank: Int, //, dtype: DType, descriptor_shape: IndexList[descriptor_rank], remaining_global_dim_rank: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE]
Creates a TMA descriptor that can handle stores with varying lengths. This struct is mainly used for multi-head attention (MHA), where sequence lengths may vary between samples.
This struct supports only one ragged dimension. The contiguous dimension (where the stride is 1) cannot be ragged.
Parameters
- descriptor_rank (Int): The rank of the descriptor shape (inferred).
- dtype (DType): The data type of the tensor.
- descriptor_shape (IndexList): The shape of the shared memory descriptor.
- remaining_global_dim_rank (Int): The rank of the remaining global tensor dimensions.
- swizzle_mode (TensorMapSwizzle): The swizzling mode to use for memory access optimization. Swizzling can improve memory access patterns for specific hardware configurations. Defaults to SWIZZLE_NONE.
Fields
- descriptor (TMADescriptor): The TMA descriptor that will be used to store the ragged tensor.
- max_length (Int): The maximum length present in the sequences of the ragged tensor.
- global_shape (IndexList[RaggedTensorMap[dtype, descriptor_shape, remaining_global_dim_rank, swizzle_mode].global_rank]): The shape of the global tensor.
- global_stride (IndexList[RaggedTensorMap[dtype, descriptor_shape, remaining_global_dim_rank, swizzle_mode].global_rank]): The stride of the global tensor.
Implemented traits
AnyType,
Copyable,
DevicePassable,
ImplicitlyCopyable,
UnknownDestructibility
Aliases
__copyinit__is_trivial
comptime __copyinit__is_trivial = False
__del__is_trivial
comptime __del__is_trivial = True
device_type
comptime device_type = RaggedTensorMap[dtype, descriptor_shape, remaining_global_dim_rank, swizzle_mode]
The device-side type, which is this RaggedTensorMap type itself.
global_rank
comptime global_rank = (remaining_global_dim_rank + 3)
The rank of the global tensor.
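As a quick check of the arithmetic in this alias, the following Python sketch mirrors `global_rank = remaining_global_dim_rank + 3`. The helper name `global_rank` is a hypothetical stand-in for illustration and is not part of the Mojo API; the sketch only reproduces the documented formula.

```python
def global_rank(remaining_global_dim_rank: int) -> int:
    """Mirror the documented alias: remaining_global_dim_rank + 3."""
    if remaining_global_dim_rank < 0:
        raise ValueError("remaining_global_dim_rank must be non-negative")
    return remaining_global_dim_rank + 3


# Print the global rank for a few values of remaining_global_dim_rank.
for remaining in range(3):
    print(remaining, global_rank(remaining))
```

With no remaining global dimensions the global tensor has rank 3, and each extra remaining dimension adds one to the lengths of `global_shape` and `global_stride`.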
ragged_descriptor_shape
comptime ragged_descriptor_shape = RaggedTensorMap._descriptor_shape[descriptor_rank, dtype, descriptor_shape, remaining_global_dim_rank, swizzle_mode]()
The shape of the descriptor used to tile and store from shared memory to global memory.
Methods
__init__
__init__(out self, ctx: DeviceContext, global_ptr: LegacyUnsafePointer[Scalar[dtype]], max_length: Int, ragged_stride: Int, batch_size: Int, global_last_dim: Int, remaining_global_dims: IndexList[remaining_global_dim_rank], remaining_global_stride: IndexList[remaining_global_dim_rank])
Initializes a TensorMapDescriptorArray with descriptors for all power-of-2 lengths.
This constructor creates a complete set of TMA descriptors, one for each power of 2 from 1 up to max_descriptor_length. Each descriptor is configured to handle a different first dimension size (1, 2, 4, 8, ..., max_descriptor_length) while maintaining the same remaining tile shape specified by desc_remaining_tile_shape.
Args:
- ctx (DeviceContext): The device context used to create the TMA descriptors.
- global_ptr (LegacyUnsafePointer): The source tensor in global memory that will be accessed using the descriptors.
- max_length (Int): The maximum length present in the sequences of the ragged tensor.
- ragged_stride (Int): The stride of the ragged dimension in the global tensor.
- batch_size (Int): The total number of sequences in the ragged tensor.
- global_last_dim (Int): The last dimension of the global tensor.
- remaining_global_dims (IndexList): The dimensions of the remaining global tensor.
- remaining_global_stride (IndexList): The stride of the remaining global tensor.
Raises:
If the operation fails.
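The constructor description mentions first-dimension sizes of 1, 2, 4, 8, and so on up to max_descriptor_length. The Python sketch below only enumerates that sequence. The function name `power_of_two_lengths` is hypothetical, and the requirement that max_descriptor_length is itself a power of 2 is an assumption made for this illustration, not something the page states.

```python
def power_of_two_lengths(max_descriptor_length: int) -> list:
    """List 1, 2, 4, ... up to and including max_descriptor_length.

    Assumption (not from the documentation): max_descriptor_length is a
    positive power of 2, so the sequence ends exactly on it.
    """
    is_power_of_two = (
        max_descriptor_length >= 1
        and max_descriptor_length & (max_descriptor_length - 1) == 0
    )
    if not is_power_of_two:
        raise ValueError("max_descriptor_length must be a positive power of 2")

    lengths = []
    size = 1
    while size <= max_descriptor_length:
        lengths.append(size)
        size *= 2  # Each entry doubles the previous first-dimension size.
    return lengths


print(power_of_two_lengths(8))
```

For a maximum of 8 this yields four sizes, so the number of first-dimension sizes grows logarithmically with the maximum length.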
get_type_name
static get_type_name() -> String
Returns a string representation of this type.
Returns:
String: A string containing the type name with all template parameters.
get_device_type_name
static get_device_type_name() -> String
Returns the device type name for this descriptor array.
Returns:
String: A string containing the type name with all template parameters.
store_ragged_tile
store_ragged_tile[rank: Int, //, using_max_descriptor_size: Bool = False](self, coordinates: IndexList[rank], preceding_cumulative_length: Int, store_length: Int, mut tile_iterator: LayoutTensorIter[dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked])
Stores a ragged tile from shared memory to global memory.
Parameters:
- rank (Int): The rank of the coordinates.
- using_max_descriptor_size (Bool): If True, optimizes the store around the max descriptor size.
Args:
- coordinates (IndexList): The starting coordinates of all dimensions except the ragged dimension.
- preceding_cumulative_length (Int): The cumulative length of the preceding sequences.
- store_length (Int): The length of the current sequence to be stored.
- tile_iterator (LayoutTensorIter): The iterator over the tile in shared memory.
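To make the relationship between `preceding_cumulative_length` and `store_length` concrete, here is a minimal Python sketch that derives both values from a list of per-sequence lengths. It assumes the sequences are packed back to back along the ragged dimension, which is an interpretation of the argument descriptions above. The helper `ragged_store_args` is hypothetical and does not call the Mojo API.

```python
from itertools import accumulate


def ragged_store_args(sequence_lengths):
    """Return (preceding_cumulative_length, store_length) per sequence.

    Assumption: sequences are packed contiguously along the ragged
    dimension, so each sequence starts where the previous one ends.
    """
    # Exclusive prefix sum: total length of all sequences before each one.
    offsets = [0] + list(accumulate(sequence_lengths))[:-1]
    return list(zip(offsets, sequence_lengths))


# Three sequences of lengths 3, 5, and 2.
print(ragged_store_args([3, 5, 2]))  # [(0, 3), (3, 5), (8, 2)]
```

Under this assumption, the first sequence has no preceding length, and each later sequence is offset by the sum of the lengths before it.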
prefetch_descriptor
prefetch_descriptor(self)
Prefetches the TMA descriptor into cache.