Mojo struct

TMATensorTileArray

@register_passable(trivial) struct TMATensorTileArray[num_of_tensormaps: Int, dtype: DType, cta_tile_layout: Layout, desc_layout: Layout]

An array of TMA descriptors.

Parameters

  • num_of_tensormaps (Int): The number of TMA descriptors (also known as tensor maps).
  • dtype (DType): The data type of the tensor elements.
  • cta_tile_layout (Layout): The layout of the tile in shared memory, typically specified as row_major.
  • desc_layout (Layout): The layout of the descriptor, which can differ from the shared memory layout to accommodate hardware requirements such as WGMMA.

Fields

  • tensormaps (StaticTuple[UnsafePointer[TMATensorTile[dtype, cta_tile_layout, desc_layout]], num_of_tensormaps]): A static tuple of pointers to the TMA descriptors, one per tensor map.

Implemented traits

AnyType, Copyable, ExplicitlyCopyable, Movable, UnknownDestructibility

Methods

__init__

__init__(out self, ctx: DeviceContext, tensormaps_device: DeviceBuffer[uint8], template_tma_tensormap: Optional[TMATensorTile[dtype, cta_tile_layout, desc_layout]])

Initializes a new TMATensorTileArray.

Args:

  • ctx (DeviceContext): Device context.
  • tensormaps_device (DeviceBuffer[uint8]): Device buffer to store TMA descriptors.
  • template_tma_tensormap (Optional[TMATensorTile[dtype, cta_tile_layout, desc_layout]]): TMA descriptor template.

__getitem__

__getitem__(self, index: Int) -> UnsafePointer[TMATensorTile[dtype, cta_tile_layout, desc_layout]]

Retrieves a pointer to the TMA descriptor at the given index.

Args:

  • index (Int): Index of the TMA descriptor.
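
The relationship between the byte buffer, the optional template, and indexed access can be pictured with a small host-side model. The sketch below is plain Python, not the Mojo API: the class name DescriptorArrayModel and the constant DESCRIPTOR_BYTES are hypothetical. It assumes that descriptors are packed back to back at a 128-byte stride (the size of a CUDA CUtensorMap) and that a supplied template is copied into every slot. The reference above does not confirm either assumption.

```python
# Conceptual model only; none of these names belong to the Mojo library.
DESCRIPTOR_BYTES = 128  # size of a CUDA CUtensorMap; assumed stride between slots


class DescriptorArrayModel:
    """Models descriptor slots packed back to back in one byte buffer."""

    def __init__(self, num_of_tensormaps, buffer, template=None):
        if len(buffer) < num_of_tensormaps * DESCRIPTOR_BYTES:
            raise ValueError("buffer too small for the requested descriptors")
        if template is not None and len(template) != DESCRIPTOR_BYTES:
            raise ValueError("template must be exactly one descriptor long")
        self.buffer = buffer
        # One "pointer" (here a byte offset) per descriptor slot.
        self.offsets = [i * DESCRIPTOR_BYTES for i in range(num_of_tensormaps)]
        if template is not None:
            # Assumed behavior: seed every slot with a copy of the template.
            for off in self.offsets:
                buffer[off:off + DESCRIPTOR_BYTES] = template

    def __getitem__(self, index):
        # Return a writable view of the slot, analogous to returning a pointer.
        off = self.offsets[index]
        return memoryview(self.buffer)[off:off + DESCRIPTOR_BYTES]


buffer = bytearray(4 * DESCRIPTOR_BYTES)
template = bytes([0xAB]) * DESCRIPTOR_BYTES
array = DescriptorArrayModel(4, buffer, template)
slot = array[2]  # view of the third descriptor slot
```

Because indexing returns a view rather than a copy, a write through slot changes the underlying buffer. This mirrors why __getitem__ is documented as returning an UnsafePointer and not a TMATensorTile value.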