
Mojo struct

TileLoaderTMA

@register_passable(trivial) struct TileLoaderTMA[a_tma_origin: ImmutOrigin, b_tma_origin: ImmutOrigin, a_type: DType, b_type: DType, a_layout: Layout, b_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, //, a_smem_layout: Layout, b_smem_layout: Layout, BM: Int, BN: Int, BK: Int, MMA_N: Int, cta_group: Int, k_group_size: Int, num_pipeline_stages: Int, num_group_stages: Int]

TMA-based tile loader for SM100.

Encapsulates the complete tile loading logic including:

  • K-group batching (multiple tiles per barrier)
  • CTA group coordination (1-SM or 2-SM cooperative loading)
  • Peer CTA slicing for 2-SM MMA
  • expect_bytes management

Template Parameters:

  • a_tma_origin: Origin type for A TMA pointer.
  • b_tma_origin: Origin type for B TMA pointer.
  • a_type: Data type for A matrix.
  • b_type: Data type for B matrix.
  • a_layout: Global memory layout for A.
  • b_layout: Global memory layout for B.
  • a_desc_layout: TMA descriptor layout for A.
  • b_desc_layout: TMA descriptor layout for B.
  • a_smem_layout: Shared memory tile layout for A.
  • b_smem_layout: Shared memory tile layout for B.
  • BM: Block tile M dimension.
  • BN: Block tile N dimension.
  • BK: Block tile K dimension.
  • MMA_N: MMA N dimension for B coordinate calculation.
  • cta_group: Number of CTAs cooperating, 1 or 2.
  • k_group_size: Number of K tiles per barrier sync.
  • num_pipeline_stages: Total pipeline stages.
  • num_group_stages: Pipeline stages / k_group_size.
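The relationship between the staging parameters can be modeled with a small Python sketch. Python is used here only to illustrate the arithmetic; `check_stage_params` is a hypothetical helper, not part of the Mojo API. The sketch assumes that `num_pipeline_stages` divides evenly by `k_group_size`, which this page does not state explicitly.

```python
def check_stage_params(cta_group: int, k_group_size: int, num_pipeline_stages: int) -> int:
    """Return num_group_stages for the given parameters (hypothetical helper)."""
    # The page documents cta_group as "1 or 2".
    if cta_group not in (1, 2):
        raise ValueError("cta_group must be 1 or 2")
    # Assumption: the pipeline stages split evenly into groups of k_group_size.
    if k_group_size <= 0 or num_pipeline_stages % k_group_size != 0:
        raise ValueError("num_pipeline_stages must be a positive multiple of k_group_size")
    # num_group_stages = pipeline stages / k_group_size
    return num_pipeline_stages // k_group_size


num_group_stages = check_stage_params(cta_group=2, k_group_size=2, num_pipeline_stages=6)
```

With these illustrative values, six pipeline stages grouped two K tiles at a time give three group stages, so a barrier is signaled once per pair of tiles instead of once per tile.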

Fields

  • a_tma_op (TileLoaderTMA[a_smem_layout, b_smem_layout, BM, BN, BK, MMA_N, cta_group, k_group_size, num_pipeline_stages, num_group_stages].ATmaOpPtr):
  • b_tma_op (TileLoaderTMA[a_smem_layout, b_smem_layout, BM, BN, BK, MMA_N, cta_group, k_group_size, num_pipeline_stages, num_group_stages].BTmaOpPtr):
  • a_multicast_mask (UInt16):
  • b_multicast_mask (UInt16):
  • peer_rank_n (UInt):
  • peer_rank_m (UInt):
  • peer_m_rank (UInt):
  • work_m_coord (UInt):
  • work_n_coord (UInt):

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, Movable, UnknownDestructibility

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

a_expected_bytes

comptime a_expected_bytes = (a_smem_layout.size() * size_of[a_type]())

a_tma_load_size

comptime a_tma_load_size = a_desc_layout.size()

a_tma_rows

comptime a_tma_rows = a_desc_layout.shape[0].value()

ATmaOp

comptime ATmaOp = TMATensorTile[a_type, a_layout, a_desc_layout]

ATmaOpPtr

comptime ATmaOpPtr = Pointer[TileLoaderTMA[a_smem_layout, b_smem_layout, BM, BN, BK, MMA_N, cta_group, k_group_size, num_pipeline_stages, num_group_stages].ATmaOp, a_tma_origin]

b_expected_bytes

comptime b_expected_bytes = (b_smem_layout.size() * size_of[b_type]())

b_tma_load_size

comptime b_tma_load_size = b_desc_layout.size()

b_tma_rows

comptime b_tma_rows = b_desc_layout.shape[0].value()

BTmaOp

comptime BTmaOp = TMATensorTile[b_type, b_layout, b_desc_layout]

BTmaOpPtr

comptime BTmaOpPtr = Pointer[TileLoaderTMA[a_smem_layout, b_smem_layout, BM, BN, BK, MMA_N, cta_group, k_group_size, num_pipeline_stages, num_group_stages].BTmaOp, b_tma_origin]

expected_bytes

comptime expected_bytes = ((cta_group * (TileLoaderTMA[a_smem_layout, b_smem_layout, BM, BN, BK, MMA_N, cta_group, k_group_size, num_pipeline_stages, num_group_stages].a_expected_bytes + TileLoaderTMA[a_smem_layout, b_smem_layout, BM, BN, BK, MMA_N, cta_group, k_group_size, num_pipeline_stages, num_group_stages].b_expected_bytes)) * k_group_size)
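The `expected_bytes` formula above can be checked by hand with a short Python sketch that mirrors the three comptime members. The function name and the tile sizes below are illustrative assumptions, not values taken from the library.

```python
def expected_bytes(
    a_tile_elems: int,   # stands in for a_smem_layout.size()
    b_tile_elems: int,   # stands in for b_smem_layout.size()
    a_elem_bytes: int,   # stands in for size_of[a_type]()
    b_elem_bytes: int,   # stands in for size_of[b_type]()
    cta_group: int,
    k_group_size: int,
) -> int:
    """Mirror of the documented expected_bytes expression (illustrative)."""
    a_expected = a_tile_elems * a_elem_bytes
    b_expected = b_tile_elems * b_elem_bytes
    # (cta_group * (a_expected_bytes + b_expected_bytes)) * k_group_size
    return cta_group * (a_expected + b_expected) * k_group_size


# Hypothetical example: a 128x64 A tile and a 64x64 B tile of 2-byte elements,
# loaded by 2 cooperating CTAs with 2 K tiles per barrier.
total = expected_bytes(128 * 64, 64 * 64, 2, 2, cta_group=2, k_group_size=2)
```

In this example the A tile contributes 16384 bytes and the B tile 8192 bytes, so the barrier expects 2 * 24576 * 2 = 98304 bytes per K group.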

Methods

__init__

__init__(a_tma_op: Pointer[TileLoaderTMA[a_smem_layout, b_smem_layout, BM, BN, BK, MMA_N, cta_group, k_group_size, num_pipeline_stages, num_group_stages].ATmaOp, a_tma_origin], b_tma_op: Pointer[TileLoaderTMA[a_smem_layout, b_smem_layout, BM, BN, BK, MMA_N, cta_group, k_group_size, num_pipeline_stages, num_group_stages].BTmaOp, b_tma_origin], a_multicast_mask: UInt16, b_multicast_mask: UInt16, peer_cta_coord: Tuple[UInt, UInt, UInt]) -> Self

Initialize the TMA tile loader.

Args:

  • a_tma_op (Pointer): Pointer to A matrix TMA descriptor.
  • b_tma_op (Pointer): Pointer to B matrix TMA descriptor.
  • a_multicast_mask (UInt16): Multicast mask for A tiles.
  • b_multicast_mask (UInt16): Multicast mask for B tiles.
  • peer_cta_coord (Tuple): Peer CTA coordinates (rank_n, rank_m, peer_m_rank).

set_work_tile

set_work_tile(mut self, m_coord: UInt, n_coord: UInt)

Set the current output tile coordinates.

Args:

  • m_coord (UInt): M coordinate of the output tile.
  • n_coord (UInt): N coordinate of the output tile.

load_tiles

load_tiles[tiles_origin: MutOrigin, //](self, tiles: ProducerTiles[tiles_origin, a_type, b_type, a_smem_layout, b_smem_layout, num_pipeline_stages, num_group_stages, k_group_size], iter_idx: UInt32, elect_one_cta: Bool)

Load k_group_size A and B tiles using TMA.

Args:

  • tiles (ProducerTiles): ProducerTiles context with stage, barrier, and tile arrays.
  • iter_idx (UInt32): K iteration index (base index, not multiplied).
  • elect_one_cta (Bool): True if this CTA should call expect_bytes.
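
The control flow described for `load_tiles` can be sketched as a Python model that records events instead of issuing real TMA operations. This is not the library's implementation: `load_tiles_model` and the event names are hypothetical. Two points in it are assumptions: that `expect_bytes` is called before the loads are issued, and that tile `j` of a group starts at K offset `(iter_idx + j) * BK`. The page states only that `iter_idx` is a base index and that `k_group_size` tiles share one barrier.

```python
def load_tiles_model(
    iter_idx: int,
    elect_one_cta: bool,
    k_group_size: int,
    BK: int,
    expected_bytes: int,
) -> list:
    """Record the operations one load_tiles call might perform (illustrative)."""
    events = []
    # Only the elected CTA arms the barrier with the byte count for the whole group.
    if elect_one_cta:
        events.append(("expect_bytes", expected_bytes))
    # One A load and one B load per K tile in the group, all on the same barrier.
    for j in range(k_group_size):
        k_offset = (iter_idx + j) * BK  # assumption: offset of tile j along K
        events.append(("load_a", k_offset))
        events.append(("load_b", k_offset))
    return events


events = load_tiles_model(iter_idx=4, elect_one_cta=True, k_group_size=2, BK=64, expected_bytes=98304)
```

Under these assumptions, a call with `k_group_size=2` issues one `expect_bytes` followed by two A loads and two B loads. A CTA called with `elect_one_cta=False` issues only the loads.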
