Mojo struct
TMATensorTileArray
struct TMATensorTileArray[num_of_tensormaps: Int, dtype: DType, rank: Int, cta_tile_shape: IndexList[rank], desc_shape: IndexList[rank]]
An array of TMA descriptors.
Parameters
- num_of_tensormaps (Int): The number of TMA descriptors (tensor maps).
- dtype (DType): The data type of the tensor elements.
- rank (Int): The dimensionality of the tile (2, 3, 4, or 5).
- cta_tile_shape (IndexList[rank]): The shape of the CTA tile in shared memory.
- desc_shape (IndexList[rank]): The shape of the descriptor, which can differ from the tile shape to accommodate hardware requirements such as WGMMA.
Fields
- tensormaps_ptr (UnsafePointer[UInt8, MutAnyOrigin]): A static tuple of pointers to TMA descriptors. This field stores an array of pointers to TMATensorTile instances, where each pointer references a TMA descriptor in device memory. The array has a fixed size determined by the num_of_tensormaps parameter. The TMA descriptors are used by the GPU hardware to efficiently transfer data between global and shared memory with specific memory access patterns defined by the layouts.
Implemented traits
AnyType,
Copyable,
DevicePassable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
descriptor_bytes
comptime descriptor_bytes = 128
Size of the TMA descriptor in bytes.
This constant is used to calculate the offset of each TMA descriptor in device memory.
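The offset calculation mentioned above can be sketched as a small host-side model in Python. This is an illustration only: it assumes the descriptors are packed back to back with no padding, which this page implies but does not state, and `descriptor_offset` and `buffer_size` are hypothetical helper names that are not part of the Mojo API.

```python
DESCRIPTOR_BYTES = 128  # mirrors the comptime member descriptor_bytes


def descriptor_offset(index: int) -> int:
    """Byte offset of descriptor `index`, assuming contiguous packing."""
    return index * DESCRIPTOR_BYTES


def buffer_size(num_of_tensormaps: int) -> int:
    """Bytes needed to hold all descriptors under the same assumption."""
    return num_of_tensormaps * DESCRIPTOR_BYTES


print([descriptor_offset(i) for i in range(4)])  # [0, 128, 256, 384]
print(buffer_size(4))  # 512
```

Under this assumption, the device buffer for an array with 4 tensor maps would need at least 512 bytes.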
device_type
comptime device_type = TMATensorTileArray[num_of_tensormaps, dtype, rank, cta_tile_shape, desc_shape]
The device-side type representation.
Methods
__init__
__init__(tensormaps_device: DeviceBuffer[DType.uint8]) -> Self
Initializes a new TMATensorTileArray.
Args:
- tensormaps_device (DeviceBuffer): Device buffer to store TMA descriptors.
__getitem__
__getitem__(self, index: Int) -> UnsafePointer[TMATensorTile[dtype, rank, cta_tile_shape, desc_shape], MutAnyOrigin]
Retrieve a TMA descriptor.
Args:
- index (Int): Index of the TMA descriptor.
Returns:
UnsafePointer: UnsafePointer to the TMATensorTile at the specified index.
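As an illustration of the indexing, the following Python sketch models the descriptor storage as a flat byte buffer and returns a 128-byte view for a given index. It is an assumed model of the behavior, not the Mojo implementation: the real method returns an `UnsafePointer`, and the bounds check is an addition for this sketch, since this page does not say whether the method validates the index. `get_descriptor_view` is a hypothetical name.

```python
DESCRIPTOR_BYTES = 128  # value of descriptor_bytes from this page


def get_descriptor_view(buffer: bytearray, index: int) -> memoryview:
    """Return a writable view of the descriptor slot at `index`."""
    count = len(buffer) // DESCRIPTOR_BYTES
    if not 0 <= index < count:
        # Added for the sketch; not documented behavior of __getitem__.
        raise IndexError(f"descriptor index {index} out of range [0, {count})")
    start = index * DESCRIPTOR_BYTES
    return memoryview(buffer)[start:start + DESCRIPTOR_BYTES]


storage = bytearray(4 * DESCRIPTOR_BYTES)  # stand-in for the device buffer
view = get_descriptor_view(storage, 2)
view[0] = 0xAB  # writes through to the underlying storage
print(len(view), storage[256])  # 128 171
```

The view aliases the underlying storage rather than copying it, which loosely mirrors how a pointer to a descriptor refers to memory owned by the device buffer.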
get_type_name
static get_type_name() -> String
Gets this type's name, for use in error messages when handing arguments to kernels.
Returns:
String: This type's name.