Mojo struct
TMATensorTileArray
@register_passable(trivial)
struct TMATensorTileArray[num_of_tensormaps: Int, dtype: DType, cta_tile_layout: Layout, desc_layout: Layout]
An array of TMA descriptors.
Parameters
- num_of_tensormaps (Int): The number of TMA descriptors (also known as tensor maps).
- dtype (DType): The data type of the tensor elements.
- cta_tile_layout (Layout): The layout of the tile in shared memory, typically specified as row_major.
- desc_layout (Layout): The layout of the descriptor, which can differ from the shared memory layout to accommodate hardware requirements such as WGMMA.
Fields
- tensormaps (StaticTuple[UnsafePointer[TMATensorTile[dtype, cta_tile_layout, desc_layout]], num_of_tensormaps]): A static tuple of pointers to the TMA descriptors, one entry per tensor map.
Implemented traits
AnyType, Copyable, ExplicitlyCopyable, Movable, UnknownDestructibility
Methods
__init__
__init__(out self, ctx: DeviceContext, tensormaps_device: DeviceBuffer[uint8], template_tma_tensormap: Optional[TMATensorTile[dtype, cta_tile_layout, desc_layout]])
Initializes a new TMATensorTileArray.
Args:
- ctx (DeviceContext): Device context.
- tensormaps_device (DeviceBuffer[uint8]): Device buffer to store TMA descriptors.
- template_tma_tensormap (Optional[TMATensorTile[dtype, cta_tile_layout, desc_layout]]): TMA descriptor template.
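The `tensormaps_device` buffer is typed as raw bytes, so the caller has to decide how large it should be. The sketch below works out a plausible byte count and the per-index offsets. It rests on two assumptions that this page does not state: that each descriptor occupies 128 bytes (the size of CUDA's `CUtensorMap`), and that descriptors are packed back to back in the buffer. `DESCRIPTOR_BYTES` and `descriptor_offset` are hypothetical names used only for illustration, not part of the library.

```mojo
# Assumption: one TMA descriptor takes 128 bytes, matching CUDA's CUtensorMap.
alias DESCRIPTOR_BYTES = 128


fn descriptor_offset(index: Int) -> Int:
    # Assumption: descriptors are stored contiguously, so descriptor `index`
    # starts `index * DESCRIPTOR_BYTES` bytes into the buffer.
    return index * DESCRIPTOR_BYTES


fn main():
    alias num_of_tensormaps = 4

    # Bytes the uint8 device buffer would need to hold every descriptor.
    print("buffer bytes:", num_of_tensormaps * DESCRIPTOR_BYTES)

    # Byte offset at which each descriptor would begin.
    for i in range(num_of_tensormaps):
        print("descriptor", i, "offset", descriptor_offset(i))
```

If the library's actual descriptor size or alignment differs, only the constant needs to change; the sizing logic stays the same.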
__getitem__
__getitem__(self, index: Int) -> UnsafePointer[TMATensorTile[dtype, cta_tile_layout, desc_layout]]
Retrieves a pointer to the TMA descriptor at the given index.
Args:
- index (Int): Index of the TMA descriptor.