Mojo struct

TMAStoreCoords

@register_passable(trivial) struct TMAStoreCoords[BM: Int, BN: Int, MMA_M: Int, MMA_N: Int, stageN: Int, cta_group: Int, c_smem_shape0: Int, stage: Int]

Compute TMA store coordinates and warp election for SM100 epilogue.

Encapsulates the complex coordinate computation logic for TMA stores, including cta_group-specific branching and warp election.

Template Parameters:

  • BM (Int): Block M dimension.
  • BN (Int): Block N dimension.
  • MMA_M (Int): MMA M dimension.
  • MMA_N (Int): MMA N dimension.
  • stageN (Int): Stage width in elements.
  • cta_group (Int): Number of CTAs cooperating (1 or 2).
  • c_smem_shape0 (Int): Shape[0] of the shared memory tile layout.
  • stage (Int): Current output stage index.

Fields

  • coord_m (UInt):
  • coord_n (UInt):
  • elect_one_warp (Bool):
  • c_smem_coord_m (UInt):

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, Movable, UnknownDestructibility

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

CG1_TMA_BM

comptime CG1_TMA_BM = c_smem_shape0

CG2_TMA_BM

comptime CG2_TMA_BM = c_smem_shape0 if (MMA_M == 256) else BM

stage_n_offset

comptime stage_n_offset = (stage * stageN)
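
As a worked illustration, the offset grows linearly with the stage index. The sketch below models the compile-time expression in plain Python rather than Mojo; the function name and the sample stage width of 32 elements are assumptions chosen for illustration, not values taken from the library.

```python
def stage_n_offset(stage: int, stage_n: int) -> int:
    """Plain-Python model of `comptime stage_n_offset = (stage * stageN)`."""
    return stage * stage_n


# Hypothetical stage width of 32 elements; offsets for the first four stages.
offsets = [stage_n_offset(s, 32) for s in range(4)]
print(offsets)  # [0, 32, 64, 96]
```

Each output stage therefore starts stageN elements after the previous one.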

TMA_BM

comptime TMA_BM = c_smem_shape0 if (MMA_M == 256) else BM if (cta_group == 2) else TMAStoreCoords[BM, BN, MMA_M, MMA_N, stageN, cta_group, c_smem_shape0, stage].CG1_TMA_BM
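
The nested conditional above can be hard to read at a glance. The following is a minimal plain-Python sketch (not Mojo) of the selection logic, assuming TMA_BM picks CG2_TMA_BM when cta_group is 2 and CG1_TMA_BM otherwise, as the member names suggest. The function names and the sample values (BM = 128, c_smem_shape0 = 64) are hypothetical.

```python
def cg1_tma_bm(c_smem_shape0: int) -> int:
    # Model of: comptime CG1_TMA_BM = c_smem_shape0
    return c_smem_shape0


def cg2_tma_bm(bm: int, mma_m: int, c_smem_shape0: int) -> int:
    # Model of: comptime CG2_TMA_BM = c_smem_shape0 if (MMA_M == 256) else BM
    return c_smem_shape0 if mma_m == 256 else bm


def tma_bm(bm: int, mma_m: int, cta_group: int, c_smem_shape0: int) -> int:
    # Assumed reading: use the cta_group == 2 variant, else the single-CTA one.
    if cta_group == 2:
        return cg2_tma_bm(bm, mma_m, c_smem_shape0)
    return cg1_tma_bm(c_smem_shape0)


def tma_bm_literal(bm: int, mma_m: int, cta_group: int, c_smem_shape0: int) -> int:
    # The expression exactly as printed on this page, with Python precedence.
    return (c_smem_shape0 if (mma_m == 256)
            else bm if (cta_group == 2)
            else cg1_tma_bm(c_smem_shape0))


# Hypothetical sample values: BM = 128, c_smem_shape0 = 64.
print(tma_bm(128, 256, 2, 64))  # 64  (two CTAs, MMA_M == 256)
print(tma_bm(128, 128, 2, 64))  # 128 (two CTAs, MMA_M != 256, falls back to BM)
print(tma_bm(128, 128, 1, 64))  # 64  (single CTA always uses c_smem_shape0)
```

Under either reading, the result is BM only when cta_group is 2 and MMA_M is not 256; every other combination yields c_smem_shape0.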

Methods

__init__

__init__(c_coord: Tuple[UInt32, UInt32], warp_id: UInt32) -> Self

Compute all TMA store coordinates.

Args:

  • c_coord (Tuple[UInt32, UInt32]): Output tile coordinates (m_tile, n_tile).
  • warp_id (UInt32): Current warp ID.
