
Mojo struct

StMatrixCoords

@register_passable("trivial") struct StMatrixCoords[MMA_M: Int, MMA_N: Int, stageN: Int, cta_group: Int, transpose_c: Bool]

Compute coordinates for st.matrix operations.

Encapsulates the coordinate arithmetic needed to store accumulator fragments to shared memory.

Template Parameters:

  • MMA_M (Int): MMA M dimension.
  • MMA_N (Int): MMA N dimension.
  • stageN (Int): Stage N dimension (width of each output tile).
  • cta_group (Int): Number of cooperating CTAs (1 or 2).
  • transpose_c (Bool): Whether the output is transposed.
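For orientation, here is a small plain-Python model of the lane roles in the PTX `stmatrix.sync.aligned.m8n8.x4.b16` instruction (without `.trans`). Any st.matrix coordinate scheme has to respect these roles: each lane supplies the start address of one row of one 8x8 matrix, and each lane contributes two adjacent 16-bit elements to every matrix. The sketch follows the PTX ISA description and is not part of this struct's API. The names `row_address_owner` and `fragment_coords` are hypothetical.

```python
# Plain-Python model of lane roles in PTX stmatrix.sync.aligned.m8n8.x4.b16
# (no .trans). Illustrative only; not part of the StMatrixCoords API.

WARP_SIZE = 32


def row_address_owner(lane: int) -> tuple:
    """Return (matrix_index, row) whose start address this lane supplies.

    With .x4, lanes 0-7 address the 8 rows of matrix 0, lanes 8-15 the
    rows of matrix 1, lanes 16-23 matrix 2, and lanes 24-31 matrix 3.
    """
    if not 0 <= lane < WARP_SIZE:
        raise ValueError("lane must be in [0, 32)")
    return lane // 8, lane % 8


def fragment_coords(lane: int) -> tuple:
    """Return (row, (col0, col1)) of the two 16-bit elements this lane
    contributes to each 8x8 matrix: four lanes share one row and each
    lane holds two adjacent columns."""
    if not 0 <= lane < WARP_SIZE:
        raise ValueError("lane must be in [0, 32)")
    row = lane // 4
    col = (lane % 4) * 2
    return row, (col, col + 1)
```

For example, lane 13 supplies the address of row 5 of matrix 1, and contributes columns 2 and 3 of row 3 in each of the four matrices.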

Fields

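  • warp_id (UInt32): Warp ID within the CTA.
  • lane_id (UInt32): Lane ID within the warp.
  • c_row (UInt32): Base row coordinate in global memory.
  • c_col (UInt32): Base column coordinate in global memory.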

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, Movable, UnknownDestructibility

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

Methods

__init__

__init__(warp_id: UInt32, lane_id: UInt32, c_row: UInt32, c_col: UInt32) -> Self

Initialize the coordinate calculator.

Args:

  • warp_id (UInt32): Warp ID within the CTA.
  • lane_id (UInt32): Lane ID within the warp.
  • c_row (UInt32): Base row coordinate in global memory.
  • c_col (UInt32): Base column coordinate in global memory.
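A minimal plain-Python stand-in can make the split between compile-time parameters and runtime fields explicit. In Mojo the parameters go in square brackets, for example `StMatrixCoords[128, 256, 32, 1, False](warp_id, lane_id, c_row, c_col)` (the parameter values are arbitrary illustrations). The model is hypothetical: the class name `StMatrixCoordsModel` is invented, and every range check except `cta_group` being 1 or 2 is an assumption rather than documented behavior.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class StMatrixCoordsModel:
    """Hypothetical plain-Python stand-in for StMatrixCoords.

    The first five fields mirror the struct's compile-time parameters;
    the last four mirror its runtime UInt32 fields.
    """
    MMA_M: int
    MMA_N: int
    stageN: int
    cta_group: int
    transpose_c: bool
    warp_id: int
    lane_id: int
    c_row: int
    c_col: int

    def __post_init__(self):
        # Documented constraint: 1 or 2 cooperating CTAs.
        if self.cta_group not in (1, 2):
            raise ValueError("cta_group must be 1 or 2")
        # Assumption: lane IDs come from a 32-lane warp.
        if not 0 <= self.lane_id < 32:
            raise ValueError("lane_id must be in [0, 32)")
        # Assumption: runtime values must fit in an unsigned 32-bit integer.
        for name in ("warp_id", "lane_id", "c_row", "c_col"):
            value = getattr(self, name)
            if not 0 <= value < 2**32:
                raise ValueError(f"{name} must fit in UInt32")
```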

staged_row

staged_row(self, stage: UInt32, num_stages: UInt32) -> UInt32

Compute the staged row coordinate.

Args:

  • stage (UInt32): Current stage index.
  • num_stages (UInt32): Total number of stages.

Returns:

UInt32: Row coordinate for the current stage.

staged_col

staged_col(self, stage: UInt32, num_stages: UInt32) -> UInt32

Compute the staged column coordinate.

Args:

  • stage (UInt32): Current stage index.
  • num_stages (UInt32): Total number of stages.

Returns:

UInt32: Column coordinate for the current stage.
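One plausible way to derive staged coordinates can be sketched under loudly stated assumptions: non-transposed output, `cta_group == 1`, and each stage writing the next `stageN`-wide column slice of the tile, so that `num_stages = MMA_N // stageN`. This page does not document the actual `staged_row`/`staged_col` logic, and it may differ, particularly for two-CTA or transposed configurations; the sketch rejects those cases instead of guessing. `staged_coords` is a hypothetical helper.

```python
def staged_coords(c_row: int, c_col: int, stage: int, num_stages: int,
                  stageN: int, cta_group: int = 1,
                  transpose_c: bool = False) -> tuple:
    """Hypothetical (row, col) of one output stage.

    ASSUMPTION: non-transposed output, a single CTA, and each stage
    covering the next stageN-wide column slice. Not the library's
    actual formula.
    """
    if cta_group != 1 or transpose_c:
        raise NotImplementedError(
            "only cta_group=1 with transpose_c=False is sketched")
    if not 0 <= stage < num_stages:
        raise ValueError("stage must be in [0, num_stages)")
    row = c_row                    # rows stay fixed across stages
    col = c_col + stage * stageN   # each stage shifts one slice to the right
    return row, col
```

Under this model, with `MMA_N = 256` and `stageN = 32` there are 8 stages, and stage 3 starting from `(256, 512)` lands at `(256, 608)`.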

smem_coord_m

smem_coord_m(self) -> UInt32

Compute the shared-memory M coordinate for a TMA store.

Returns:

UInt32: M coordinate in shared memory tile.
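As a rough illustration only, the sketch below assumes a layout in which the shared-memory tile is split into 32-row bands, one per warp, with the band selected by `warp_id % 4`. That layout, the constants, and the helper name `smem_coord_m_sketch` are all hypothetical; the real method may combine other inputs such as `lane_id`, `MMA_M`, or `cta_group`.

```python
WARPS_PER_GROUP = 4   # assumption: four warps share one output tile
ROWS_PER_WARP = 32    # assumption: each warp owns a 32-row band


def smem_coord_m_sketch(warp_id: int) -> int:
    """Hypothetical M offset of a warp's row band in the shared-memory tile.

    ASSUMPTION: band index is warp_id modulo 4, each band 32 rows tall.
    Not the library's actual formula.
    """
    if warp_id < 0:
        raise ValueError("warp_id must be non-negative")
    return (warp_id % WARPS_PER_GROUP) * ROWS_PER_WARP
```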
