Mojo struct

RaggedTensorMap

struct RaggedTensorMap[descriptor_rank: Int, //, dtype: DType, descriptor_shape: IndexList[descriptor_rank], remaining_global_dim_rank: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE]

Creates a TMA descriptor that can handle stores of varying lengths. This struct is mainly used for MHA, where sequence lengths may vary between samples.

This struct supports only one ragged dimension. The contiguous dimension (where the stride is 1) cannot be ragged.

Parameters

  • descriptor_rank (Int): The rank of the descriptor shape (inferred).
  • dtype (DType): The data type of the tensor.
  • descriptor_shape (IndexList): The shape of the shared memory descriptor.
  • remaining_global_dim_rank (Int): The rank of the remaining global tensor dimensions.
  • swizzle_mode (TensorMapSwizzle): The swizzling mode to use for memory access optimization. Swizzling can improve memory access patterns for specific hardware configurations. Defaults to SWIZZLE_NONE.

Fields

  • descriptor (TMADescriptor): The TMA descriptor that will be used to store the ragged tensor.
  • max_length (Int): The maximum length present in the sequences of the ragged tensor.
  • global_shape (IndexList[RaggedTensorMap[dtype, descriptor_shape, remaining_global_dim_rank, swizzle_mode].global_rank]): The shape of the global tensor.
  • global_stride (IndexList[RaggedTensorMap[dtype, descriptor_shape, remaining_global_dim_rank, swizzle_mode].global_rank]): The stride of the global tensor.

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, UnknownDestructibility

Aliases

__copyinit__is_trivial

comptime __copyinit__is_trivial = False

__del__is_trivial

comptime __del__is_trivial = True

device_type

comptime device_type = RaggedTensorMap[dtype, descriptor_shape, remaining_global_dim_rank, swizzle_mode]

The device-side type for this struct, which is RaggedTensorMap itself.

global_rank

comptime global_rank = (remaining_global_dim_rank + 3)

The rank of the global tensor.
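As a quick check of the arithmetic, the following plain-Python sketch mirrors the `global_rank` alias. It is not Mojo and not part of the library; the function name `global_rank` is a hypothetical helper introduced only for illustration.

```python
def global_rank(remaining_global_dim_rank: int) -> int:
    """Mirror the documented alias: remaining_global_dim_rank + 3.

    Hypothetical helper for illustration; the real value is computed
    at compile time inside the Mojo struct.
    """
    if remaining_global_dim_rank < 0:
        raise ValueError("remaining_global_dim_rank must be non-negative")
    return remaining_global_dim_rank + 3


if __name__ == "__main__":
    # Print the global rank for a few remaining-dimension ranks.
    for r in range(3):
        print(f"remaining_global_dim_rank={r} -> global_rank={global_rank(r)}")
```

Under this formula, `global_shape` and `global_stride` each hold `remaining_global_dim_rank + 3` entries.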

ragged_descriptor_shape

comptime ragged_descriptor_shape = RaggedTensorMap._descriptor_shape[descriptor_rank, dtype, descriptor_shape, remaining_global_dim_rank, swizzle_mode]()

The shape of the descriptor used to tile stores from shared memory to global memory.

Methods

__init__

__init__(out self, ctx: DeviceContext, global_ptr: LegacyUnsafePointer[Scalar[dtype]], max_length: Int, ragged_stride: Int, batch_size: Int, global_last_dim: Int, remaining_global_dims: IndexList[remaining_global_dim_rank], remaining_global_stride: IndexList[remaining_global_dim_rank])

Initializes a TensorMapDescriptorArray with descriptors for all power-of-2 lengths.

This constructor creates a complete set of TMA descriptors, one for each power of 2 from 1 up to max_descriptor_length. Each descriptor is configured to handle a different first dimension size (1, 2, 4, 8, ..., max_descriptor_length) while maintaining the same remaining tile shape specified by desc_remaining_tile_shape.
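The set of first-dimension sizes described above (1, 2, 4, 8, ...) can be enumerated with a short plain-Python sketch. This only illustrates the size sequence and is not the library's implementation; `power_of_two_lengths` is a hypothetical helper. It also assumes that a maximum which is not itself a power of 2 is rounded down, which the description above does not specify.

```python
def power_of_two_lengths(max_descriptor_length: int) -> list[int]:
    """List the powers of 2 from 1 up to max_descriptor_length (inclusive).

    Assumption: if max_descriptor_length is not a power of 2, the list
    stops at the largest power of 2 that does not exceed it.
    """
    if max_descriptor_length < 1:
        raise ValueError("max_descriptor_length must be at least 1")
    lengths = []
    n = 1
    while n <= max_descriptor_length:
        lengths.append(n)
        n *= 2
    return lengths


if __name__ == "__main__":
    # One entry per descriptor size in the documented sequence.
    print(power_of_two_lengths(64))
```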

Args:

  • ctx (DeviceContext): The device context used to create the TMA descriptors.
  • global_ptr (LegacyUnsafePointer): The source tensor in global memory that will be accessed using the descriptors.
  • max_length (Int): The maximum length present in the sequences of the ragged tensor.
  • ragged_stride (Int): The stride of the ragged dimension in the global tensor.
  • batch_size (Int): The total number of sequences in the ragged tensor.
  • global_last_dim (Int): The last dimension of the global tensor.
  • remaining_global_dims (IndexList): The dimensions of the remaining global tensor.
  • remaining_global_stride (IndexList): The stride of the remaining global tensor.

Raises:

If the operation fails.

get_type_name

static get_type_name() -> String

Returns a string representation of the RaggedTensorMap type.

Returns:

String: A string containing the type name with all template parameters.

get_device_type_name

static get_device_type_name() -> String

Returns the device type name for this descriptor array.

Returns:

String: A string containing the type name with all template parameters.

store_ragged_tile

store_ragged_tile[rank: Int, //, using_max_descriptor_size: Bool = False](self, coordinates: IndexList[rank], preceding_cumulative_length: Int, store_length: Int, mut tile_iterator: LayoutTensorIter[dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked])

Stores a ragged tile from shared memory to global memory.

Parameters:

  • rank (Int): The rank of the coordinates.
  • using_max_descriptor_size (Bool): If True, optimizes the store around the max descriptor size.

Args:

  • coordinates (IndexList): The starting coordinates of all dimensions except the ragged dimension.
  • preceding_cumulative_length (Int): The cumulative length of the preceding sequences.
  • store_length (Int): The length of the current sequence to be stored.
  • tile_iterator (LayoutTensorIter): The iterator over the tile in shared memory.
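To make the `preceding_cumulative_length` and `store_length` arguments concrete, here is a plain-Python sketch of how a caller might derive them from per-sequence lengths. It assumes the sequences are packed back to back along the ragged dimension, so each sequence starts at the sum of the lengths before it. The helper name `preceding_cumulative_lengths` and the sample lengths are hypothetical; no Mojo API is called.

```python
from itertools import accumulate


def preceding_cumulative_lengths(seq_lengths: list[int]) -> list[int]:
    """Return, for each sequence, the total length of all earlier sequences.

    Assumption: sequences are stored one after another along the ragged
    dimension, so this value is the sequence's starting offset.
    """
    # Exclusive prefix sum: prepend 0, then drop the final running total.
    return [0, *accumulate(seq_lengths)][:-1]


if __name__ == "__main__":
    seq_lengths = [3, 5, 2]  # hypothetical per-sample sequence lengths
    offsets = preceding_cumulative_lengths(seq_lengths)
    max_length = max(seq_lengths)  # candidate value for the max_length field
    for offset, length in zip(offsets, seq_lengths):
        print(f"preceding_cumulative_length={offset}, store_length={length}")
    print(f"max_length={max_length}")
```

In this sketch each `(offset, length)` pair corresponds to one call to `store_ragged_tile`, with `coordinates` supplying the starting positions of the non-ragged dimensions.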

prefetch_descriptor

prefetch_descriptor(self)

Prefetches the TMA descriptor into cache.
