Mojo struct
StMatrixCoords
@register_passable(trivial)
struct StMatrixCoords[MMA_M: Int, MMA_N: Int, stageN: Int, cta_group: Int, transpose_c: Bool]
Compute coordinates for st.matrix operations.
Encapsulates the coordinate calculations needed to store accumulator fragments to shared memory.
Template Parameters:
- MMA_M: MMA M dimension.
- MMA_N: MMA N dimension.
- stageN: Stage N dimension (width of each output tile).
- cta_group: Number of cooperating CTAs (1 or 2).
- transpose_c: Whether the output is transposed.
Fields
- warp_id (UInt32)
- lane_id (UInt32)
- c_row (UInt32)
- c_col (UInt32)
Implemented traits
AnyType, Copyable, ImplicitlyCopyable, Movable, UnknownDestructibility
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
Methods
__init__
__init__(warp_id: UInt32, lane_id: UInt32, c_row: UInt32, c_col: UInt32) -> Self
Initialize the coordinate calculator.
Args:
- warp_id (UInt32)
- lane_id (UInt32)
- c_row (UInt32)
- c_col (UInt32)
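In a GPU kernel, `warp_id` and `lane_id` are typically obtained by splitting a flat thread index within the CTA into a warp number and a position inside that warp. The plain-Python sketch below models only that decomposition, assuming a 32-thread warp as on NVIDIA GPUs. The helper name `split_thread_index` is hypothetical, and this page does not document how the calling kernel computes these values or what `c_row` and `c_col` hold.

```python
WARP_SIZE = 32  # threads per warp on NVIDIA GPUs (assumption for this sketch)


def split_thread_index(thread_idx: int) -> tuple[int, int]:
    """Return (warp_id, lane_id) for a flat thread index within a CTA.

    Hypothetical helper illustrating the usual warp/lane split.
    """
    if thread_idx < 0:
        raise ValueError("thread index must be non-negative")
    # Integer division selects the warp; the remainder is the lane in that warp.
    return thread_idx // WARP_SIZE, thread_idx % WARP_SIZE


# A CTA of 128 threads spans four warps.
print(split_thread_index(0))    # (0, 0)
print(split_thread_index(70))   # (2, 6)
print(split_thread_index(127))  # (3, 31)
```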
staged_row
staged_row(self, stage: UInt32, num_stages: UInt32) -> UInt32
Compute the staged row coordinate.
Args:
- stage (UInt32)
- num_stages (UInt32)
Returns:
UInt32: Row coordinate for the current stage.
staged_col
staged_col(self, stage: UInt32, num_stages: UInt32) -> UInt32
Compute the staged column coordinate.
Args:
- stage (UInt32)
- num_stages (UInt32)
Returns:
UInt32: Column coordinate for the current stage.
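`stageN` is documented as the width of each output tile. If `stageN` evenly divides `MMA_N`, an `MMA_N`-wide accumulator tile can therefore be written out in `MMA_N / stageN` stages. The plain-Python sketch below illustrates only that column tiling, under that divisibility assumption. It is not the formula `staged_row` or `staged_col` actually uses, which this page does not give, and the helper name `stage_column_ranges` is hypothetical.

```python
def stage_column_ranges(mma_n: int, stage_n: int) -> list[tuple[int, int]]:
    """Split columns [0, mma_n) into half-open ranges of width stage_n.

    Hypothetical helper; assumes stage_n evenly divides mma_n.
    """
    if stage_n <= 0 or mma_n % stage_n != 0:
        raise ValueError("stage_n must be positive and divide mma_n")
    num_stages = mma_n // stage_n
    # Stage s covers columns [s * stage_n, (s + 1) * stage_n).
    return [(s * stage_n, (s + 1) * stage_n) for s in range(num_stages)]


# Example sizes chosen for illustration only.
print(stage_column_ranges(256, 64))
# [(0, 64), (64, 128), (128, 192), (192, 256)]
```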
smem_coord_m
smem_coord_m(self) -> UInt32
Compute the shared-memory M coordinate for the TMA store.
Returns:
UInt32: M coordinate in shared memory tile.