Mojo struct
TileLoaderTMA
@register_passable(trivial)
struct TileLoaderTMA[a_tma_origin: ImmutOrigin, b_tma_origin: ImmutOrigin, a_type: DType, b_type: DType, a_layout: Layout, b_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, //, a_smem_layout: Layout, b_smem_layout: Layout, BM: Int, BN: Int, BK: Int, MMA_N: Int, cta_group: Int, k_group_size: Int, num_pipeline_stages: Int, num_group_stages: Int]
TMA-based tile loader for SM100.
Encapsulates the complete tile loading logic including:
- K-group batching (multiple tiles per barrier)
- CTA group coordination (1-SM or 2-SM cooperative loading)
- Peer CTA slicing for 2-SM MMA
- expect_bytes management
Template Parameters:
- a_tma_origin: Origin type for the A TMA pointer.
- b_tma_origin: Origin type for the B TMA pointer.
- a_type: Data type of the A matrix.
- b_type: Data type of the B matrix.
- a_layout: Global memory layout of A.
- b_layout: Global memory layout of B.
- a_desc_layout: TMA descriptor layout for A.
- b_desc_layout: TMA descriptor layout for B.
- a_smem_layout: Shared memory tile layout for A.
- b_smem_layout: Shared memory tile layout for B.
- BM: Block tile M dimension.
- BN: Block tile N dimension.
- BK: Block tile K dimension.
- MMA_N: MMA N dimension, used to compute the B tile coordinate.
- cta_group: Number of cooperating CTAs (1 or 2).
- k_group_size: Number of K tiles loaded per barrier synchronization.
- num_pipeline_stages: Total number of pipeline stages.
- num_group_stages: Number of stage groups (num_pipeline_stages / k_group_size).
Parameters before the `//` marker in the signature (a_tma_origin through b_desc_layout) are infer-only; Mojo deduces them from the argument types.
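The relationship between num_pipeline_stages, k_group_size, and num_group_stages can be sketched as plain Python arithmetic. This is a model, not Mojo code and not part of the library; the requirement that the stage count divide evenly by k_group_size is an assumption made for the sketch.

```python
def group_stages(num_pipeline_stages: int, k_group_size: int) -> int:
    """Number of barrier-synchronized stage groups.

    Assumption (not confirmed by the API docs): the pipeline depth must
    be an exact multiple of k_group_size.
    """
    if k_group_size <= 0:
        raise ValueError("k_group_size must be positive")
    if num_pipeline_stages % k_group_size != 0:
        raise ValueError("num_pipeline_stages must be a multiple of k_group_size")
    return num_pipeline_stages // k_group_size


# 8 pipeline stages batched 2 K tiles per barrier -> 4 stage groups.
print(group_stages(8, 2))
```

With k_group_size = 1, every pipeline stage has its own barrier and num_group_stages equals num_pipeline_stages.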
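Fields
- a_tma_op (TileLoaderTMA[a_smem_layout, b_smem_layout, BM, BN, BK, MMA_N, cta_group, k_group_size, num_pipeline_stages, num_group_stages].ATmaOpPtr): Pointer to the A matrix TMA descriptor.
- b_tma_op (TileLoaderTMA[a_smem_layout, b_smem_layout, BM, BN, BK, MMA_N, cta_group, k_group_size, num_pipeline_stages, num_group_stages].BTmaOpPtr): Pointer to the B matrix TMA descriptor.
- a_multicast_mask (UInt16): Multicast mask for A tiles.
- b_multicast_mask (UInt16): Multicast mask for B tiles.
- peer_rank_n (UInt): rank_n component of the peer CTA coordinates.
- peer_rank_m (UInt): rank_m component of the peer CTA coordinates.
- peer_m_rank (UInt): peer_m_rank component of the peer CTA coordinates.
- work_m_coord (UInt): M coordinate of the current output tile, set by set_work_tile.
- work_n_coord (UInt): N coordinate of the current output tile, set by set_work_tile.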
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
Movable,
UnknownDestructibility
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
a_expected_bytes
comptime a_expected_bytes = (a_smem_layout.size() * size_of[a_type]())
a_tma_load_size
comptime a_tma_load_size = a_desc_layout.size()
a_tma_rows
comptime a_tma_rows = a_desc_layout.shape[0].value()
ATmaOp
comptime ATmaOp = TMATensorTile[a_type, a_layout, a_desc_layout]
ATmaOpPtr
comptime ATmaOpPtr = Pointer[TileLoaderTMA[a_smem_layout, b_smem_layout, BM, BN, BK, MMA_N, cta_group, k_group_size, num_pipeline_stages, num_group_stages].ATmaOp, a_tma_origin]
b_expected_bytes
comptime b_expected_bytes = (b_smem_layout.size() * size_of[b_type]())
b_tma_load_size
comptime b_tma_load_size = b_desc_layout.size()
b_tma_rows
comptime b_tma_rows = b_desc_layout.shape[0].value()
BTmaOp
comptime BTmaOp = TMATensorTile[b_type, b_layout, b_desc_layout]
BTmaOpPtr
comptime BTmaOpPtr = Pointer[TileLoaderTMA[a_smem_layout, b_smem_layout, BM, BN, BK, MMA_N, cta_group, k_group_size, num_pipeline_stages, num_group_stages].BTmaOp, b_tma_origin]
expected_bytes
comptime expected_bytes = ((cta_group * (TileLoaderTMA[a_smem_layout, b_smem_layout, BM, BN, BK, MMA_N, cta_group, k_group_size, num_pipeline_stages, num_group_stages].a_expected_bytes + TileLoaderTMA[a_smem_layout, b_smem_layout, BM, BN, BK, MMA_N, cta_group, k_group_size, num_pipeline_stages, num_group_stages].b_expected_bytes)) * k_group_size)
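Reading the expected_bytes formula: the elected CTA's barrier must account for the A and B tiles of every cooperating CTA (the cta_group factor) and for every tile in the K group (the k_group_size factor). The Python model below reproduces that arithmetic. The tile shapes (128 x 64 elements for both A and B) and the 2-byte element type (for example bf16) are illustrative assumptions, not values taken from the library.

```python
def tile_bytes(rows: int, cols: int, elem_bytes: int) -> int:
    """Bytes in one shared memory tile: element count times element size."""
    return rows * cols * elem_bytes


def expected_bytes(a_bytes: int, b_bytes: int, cta_group: int, k_group_size: int) -> int:
    """Mirrors: (cta_group * (a_expected_bytes + b_expected_bytes)) * k_group_size."""
    return (cta_group * (a_bytes + b_bytes)) * k_group_size


# Assumed example: A tile 128 x 64, B tile 128 x 64, 2-byte elements.
a = tile_bytes(128, 64, 2)  # 16384 bytes
b = tile_bytes(128, 64, 2)  # 16384 bytes
print(expected_bytes(a, b, cta_group=2, k_group_size=2))
```

Under these assumptions a 2-SM, 2-tile K group posts 2 * (16384 + 16384) * 2 = 131072 bytes to the barrier.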
Methods
__init__
__init__(a_tma_op: Pointer[TileLoaderTMA[a_smem_layout, b_smem_layout, BM, BN, BK, MMA_N, cta_group, k_group_size, num_pipeline_stages, num_group_stages].ATmaOp, a_tma_origin], b_tma_op: Pointer[TileLoaderTMA[a_smem_layout, b_smem_layout, BM, BN, BK, MMA_N, cta_group, k_group_size, num_pipeline_stages, num_group_stages].BTmaOp, b_tma_origin], a_multicast_mask: UInt16, b_multicast_mask: UInt16, peer_cta_coord: Tuple[UInt, UInt, UInt]) -> Self
Initialize the TMA tile loader.
Args:
- a_tma_op (Pointer): Pointer to the A matrix TMA descriptor.
- b_tma_op (Pointer): Pointer to the B matrix TMA descriptor.
- a_multicast_mask (UInt16): Multicast mask for A tiles.
- b_multicast_mask (UInt16): Multicast mask for B tiles.
- peer_cta_coord (Tuple): Peer CTA coordinates (rank_n, rank_m, peer_m_rank).
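A 16-bit multicast mask has one bit per CTA in the cluster; a set bit means that CTA receives the tile. The sketch below builds such masks in Python. The cluster shape, the rank ordering rank = m + n * cluster_m, and the rule that A is shared along a cluster row while B is shared along a cluster column are all hypothetical choices for illustration; how the caller actually computes a_multicast_mask and b_multicast_mask is not specified here.

```python
def multicast_mask(ranks) -> int:
    """Set bit r for every destination CTA rank r (must fit in 16 bits)."""
    mask = 0
    for r in ranks:
        if not 0 <= r < 16:
            raise ValueError("cluster rank must fit in a 16-bit mask")
        mask |= 1 << r
    return mask


def cta_rank(m: int, n: int, cluster_m: int) -> int:
    """Hypothetical rank ordering: M varies fastest."""
    return m + n * cluster_m


cluster_m, cluster_n = 2, 2
m, n = 0, 1  # this CTA's position in the cluster (hypothetical)

# A tile for row m is needed by every CTA with the same m.
a_mask = multicast_mask(cta_rank(m, j, cluster_m) for j in range(cluster_n))
# B tile for column n is needed by every CTA with the same n.
b_mask = multicast_mask(cta_rank(i, n, cluster_m) for i in range(cluster_m))
print(bin(a_mask), bin(b_mask))
```

For the CTA at (m=0, n=1) in a 2 x 2 cluster, A goes to ranks 0 and 2 (mask 0b101) and B to ranks 2 and 3 (mask 0b1100).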
set_work_tile
set_work_tile(mut self, m_coord: UInt, n_coord: UInt)
Set the current output tile coordinates.
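Args:
- m_coord (UInt): M coordinate of the current output tile.
- n_coord (UInt): N coordinate of the current output tile.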
load_tiles
load_tiles[tiles_origin: MutOrigin, //](self, tiles: ProducerTiles[tiles_origin, a_type, b_type, a_smem_layout, b_smem_layout, num_pipeline_stages, num_group_stages, k_group_size], iter_idx: UInt32, elect_one_cta: Bool)
Load k_group_size A tiles and k_group_size B tiles with TMA, all tracked by a single barrier.
Args:
- tiles (ProducerTiles): ProducerTiles context holding the stage, barrier, and tile arrays.
- iter_idx (UInt32): K iteration index (the base index, not multiplied by k_group_size).
- elect_one_cta (Bool): True if this CTA should call expect_bytes.
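Because iter_idx is a base index that has not been multiplied, one call covers k_group_size consecutive K tiles. The Python model below shows one plausible mapping from iter_idx to the K element offsets loaded in a call. It assumes tile j of the group sits at K tile index iter_idx * k_group_size + j, each BK elements wide; this mapping is hypothetical and is not taken from the implementation.

```python
def k_group_offsets(iter_idx: int, k_group_size: int, BK: int) -> list:
    """K element offsets of the tiles loaded for one barrier.

    Hypothetical model: the loader scales the base index by k_group_size
    and issues one A load and one B load per K tile in the group.
    """
    base_tile = iter_idx * k_group_size
    return [(base_tile + j) * BK for j in range(k_group_size)]


# Iteration 3 with 2 tiles per barrier and BK = 64 covers K tiles 6 and 7.
print(k_group_offsets(3, 2, 64))  # [384, 448]
```

Only the CTA with elect_one_cta set would post expected_bytes for the whole group; the other CTA in a 2-SM pair issues its loads without calling expect_bytes.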