
Mojo struct

TMATensorTileArray

struct TMATensorTileArray[num_of_tensormaps: Int, dtype: DType, rank: Int, cta_tile_shape: IndexList[rank], desc_shape: IndexList[rank]]

An array of TMA descriptors.

Parameters

  • num_of_tensormaps (Int): The number of TMA descriptors (also called tensor maps).
  • dtype (DType): The data type of the tensor elements.
  • rank (Int): The dimensionality of the tile (2, 3, 4, or 5).
  • cta_tile_shape (IndexList[rank]): The shape of the CTA tile in shared memory.
  • desc_shape (IndexList[rank]): The shape of the descriptor, which can differ from the tile shape to accommodate hardware requirements such as WGMMA.

Fields

  • tensormaps_ptr (UnsafePointer[UInt8, MutAnyOrigin]): A byte pointer to the device memory that holds the TMA descriptors. Each descriptor is a TMATensorTile occupying descriptor_bytes bytes, and the number of descriptors is fixed by the num_of_tensormaps parameter.

    The TMA descriptors are used by the GPU hardware to efficiently transfer data between global and shared memory with specific memory access patterns defined by the layouts.

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

descriptor_bytes

comptime descriptor_bytes = 128

Size of the TMA descriptor in bytes.

The value is constant. It is used to calculate the byte offset of each TMA descriptor in device memory.
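
The offset calculation mentioned above can be modeled in a few lines. This is a plain-Python sketch of the arithmetic, not Mojo API code: the helper names `descriptor_offset` and `required_buffer_bytes` are hypothetical, and it assumes the descriptors are packed contiguously with no padding between them.

```python
DESCRIPTOR_BYTES = 128  # mirrors the documented descriptor_bytes constant


def descriptor_offset(index: int) -> int:
    """Byte offset of descriptor `index`, assuming contiguous packing."""
    if index < 0:
        raise IndexError("descriptor index must be non-negative")
    return index * DESCRIPTOR_BYTES


def required_buffer_bytes(num_of_tensormaps: int) -> int:
    """Smallest buffer size that holds `num_of_tensormaps` descriptors."""
    return num_of_tensormaps * DESCRIPTOR_BYTES


if __name__ == "__main__":
    # Offsets of the first four descriptors, then the total bytes they need.
    print([descriptor_offset(i) for i in range(4)])
    print(required_buffer_bytes(4))
```

Under the same assumption, a buffer passed to `__init__` would need at least `num_of_tensormaps * descriptor_bytes` bytes.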

device_type

comptime device_type = TMATensorTileArray[num_of_tensormaps, dtype, rank, cta_tile_shape, desc_shape]

The device-side type representation.

Methods

__init__

__init__(tensormaps_device: DeviceBuffer[DType.uint8]) -> Self

Initializes a new TMATensorTileArray.

Args:

  • tensormaps_device (DeviceBuffer): Device buffer to store TMA descriptors.

__getitem__

__getitem__(self, index: Int) -> UnsafePointer[TMATensorTile[dtype, rank, cta_tile_shape, desc_shape], MutAnyOrigin]

Retrieves a pointer to the TMA descriptor at the given index.

Args:

  • index (Int): Index of the TMA descriptor.

Returns:

UnsafePointer: UnsafePointer to the TMATensorTile at the specified index.
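
As a rough model of what indexing does, the sketch below treats the descriptor storage as a flat byte buffer and returns a view of one 128-byte slot. It is plain Python, not the Mojo implementation: `DescriptorArrayModel` is a hypothetical name, and the bounds check is added for illustration only (the documented signature does not say whether the index is checked).

```python
class DescriptorArrayModel:
    """Host-side stand-in for a fixed-size array of 128-byte descriptors."""

    DESCRIPTOR_BYTES = 128  # mirrors the documented descriptor_bytes constant

    def __init__(self, buffer: bytearray, num_of_tensormaps: int):
        needed = num_of_tensormaps * self.DESCRIPTOR_BYTES
        if len(buffer) < needed:
            raise ValueError(f"buffer holds {len(buffer)} bytes, need {needed}")
        self._view = memoryview(buffer)
        self._count = num_of_tensormaps

    def __getitem__(self, index: int) -> memoryview:
        # Illustrative bounds check; not claimed for the real struct.
        if not 0 <= index < self._count:
            raise IndexError(index)
        start = index * self.DESCRIPTOR_BYTES
        return self._view[start:start + self.DESCRIPTOR_BYTES]


if __name__ == "__main__":
    storage = bytearray(4 * 128)
    array = DescriptorArrayModel(storage, 4)
    array[2][0] = 0xAB          # write through the view of slot 2
    print(storage[2 * 128])     # the write lands at byte offset 256
```

The view aliases the underlying buffer rather than copying it, which is the property the pointer return type provides in the real struct.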

get_type_name

static get_type_name() -> String

Gets this type's name, for use in error messages when passing arguments to kernels.

Returns:

String: This type's name.
