
Mojo struct

StMatrixWriter

@register_passable(trivial) struct StMatrixWriter[c_type: DType, c_smem_layout: Layout, stageN: Int, c_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, transpose_c: Bool = False]

Write register fragments to shared memory using st.matrix.

Handles the swizzling and address computation required for efficient shared-memory writes of WGMMA accumulator fragments.

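Template Parameters:

  • c_type (DType): Output data type.
  • c_smem_layout (Layout): Shared memory tile layout.
  • stageN (Int): Stage width in elements.
  • c_swizzle (TensorMapSwizzle): TMA swizzle mode. Defaults to TensorMapSwizzle.SWIZZLE_128B.
  • transpose_c (Bool): Whether the output is transposed. Defaults to False.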
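The default c_swizzle is TensorMapSwizzle.SWIZZLE_128B. The Python model below illustrates how that swizzle pattern permutes 16-byte chunks within each 128-byte row. It is not the Mojo implementation. It assumes the usual description of the 128-byte pattern: bits [4, 7) of the byte offset are XORed with bits [7, 10), with the tile base 1024-byte aligned. The helper names `swizzle_128b` and `swizzled_chunk` are hypothetical.

```python
def swizzle_128b(byte_offset: int) -> int:
    """Model of the 128-byte TMA swizzle applied to a shared-memory byte offset.

    Bits [4, 7) select a 16-byte chunk inside a 128-byte row; they are XORed
    with bits [7, 10), i.e. the row index modulo 8. Assumes a 1024-byte-aligned
    tile base (hypothetical helper, not part of the Mojo API).
    """
    return byte_offset ^ ((byte_offset >> 3) & 0x70)


def swizzled_chunk(row: int, chunk: int) -> int:
    """Physical 16-byte slot of logical chunk `chunk` in 128-byte row `row`."""
    return swizzle_128b(row * 128 + chunk * 16) % 128 // 16


# Each row gets a different permutation of its eight chunks (chunk ^ (row % 8)).
for row in range(8):
    print(row, [swizzled_chunk(row, c) for c in range(8)])
```

Row 0 is left unchanged, while row 1 maps chunks to `[1, 0, 3, 2, 5, 4, 7, 6]`. Because eight consecutive rows use eight different permutations, the same logical column lands in different banks across rows.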

Fields

  • swizzle (Swizzle): Swizzle pattern applied to shared-memory offsets.
  • lane_id (UInt32): Lane ID within the warp, as passed to __init__.

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, Movable, UnknownDestructibility

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

Config

comptime Config = StMatrixConfig[c_type, stageN, c_swizzle, transpose_c]

shape0

comptime shape0 = c_smem_layout.shape[1].value() if not transpose_c else c_smem_layout.shape[0].value()

stride0

comptime stride0 = c_smem_layout.stride[0].value()

stride1

comptime stride1 = c_smem_layout.stride[1].value()

stsmx_tile_offset

comptime stsmx_tile_offset = (Self.stride0 if transpose_c else Self.stride1 * Self.Config.stsmx_row_size)
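The comptime members above select the leading dimension and per-tile offset depending on transpose_c. The Python sketch below mirrors only those expressions for a 2D layout described by `shape` and `stride` tuples. The conditional parses as `stride0 if transpose_c else (stride1 * stsmx_row_size)`. The value of `Config.stsmx_row_size` is not documented here, so the `stsmx_row_size` argument and the function name `writer_constants` are hypothetical.

```python
def writer_constants(shape, stride, transpose_c, stsmx_row_size):
    """Mirror StMatrixWriter's comptime shape0, stride0, stride1, stsmx_tile_offset.

    `shape` and `stride` are (dim0, dim1) tuples of a 2D shared-memory layout.
    `stsmx_row_size` stands in for Config.stsmx_row_size (hypothetical value).
    """
    shape0 = shape[0] if transpose_c else shape[1]
    stride0, stride1 = stride
    # Same precedence as the Mojo expression: the multiply binds to the else branch.
    tile_offset = stride0 if transpose_c else stride1 * stsmx_row_size
    return shape0, stride0, stride1, tile_offset


# A row-major 64 x 32 tile: stride (32, 1).
print(writer_constants((64, 32), (32, 1), False, 16))
print(writer_constants((64, 32), (32, 1), True, 16))
```

In non-transposed mode, consecutive st.matrix tiles step along dim 1 by `stsmx_row_size` elements. In transposed mode, they step by one `stride0` instead.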

Methods

__init__

__init__(lane_id: UInt32) -> Self

Initialize the st.matrix writer.

Args:

  • lane_id (UInt32): Lane ID within the warp.

compute_lane_offset

compute_lane_offset(self) -> UInt32

Compute the base lane offset for st.matrix.

Returns:

UInt32: Lane offset in shared memory.
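The exact formula used by compute_lane_offset is not documented on this page. The Python sketch below shows the constraint any such offset must satisfy. Per the PTX definition of st.matrix with `.x4`, lanes `8*j` through `8*j + 7` supply the row addresses of 8x8 matrix `j`. The placement of the four matrices inside a 16x16 tile in `hypothetical_lane_offset` is an assumption for illustration only.

```python
def stmatrix_x4_lane_target(lane: int) -> tuple:
    """(matrix index, row within the 8x8 matrix) whose address `lane` supplies (PTX .x4)."""
    return lane // 8, lane % 8


def hypothetical_lane_offset(lane: int, row_stride_elems: int) -> int:
    """Element offset of the row addressed by `lane` in a row-major 16x16 tile.

    Hypothetical matrix placement (not taken from the Mojo source):
    matrix 0 = rows 0-7 / cols 0-7,  matrix 1 = rows 8-15 / cols 0-7,
    matrix 2 = rows 0-7 / cols 8-15, matrix 3 = rows 8-15 / cols 8-15.
    """
    matrix, row_in_matrix = stmatrix_x4_lane_target(lane)
    row = (matrix % 2) * 8 + row_in_matrix
    col = (matrix // 2) * 8
    return row * row_stride_elems + col


print([hypothetical_lane_offset(lane, 64) for lane in (0, 8, 16, 31)])
```

In the real writer, a base offset like this would additionally be passed through the swizzle before use.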

write_fragment

write_fragment[frag_size: Int](self, frag: SIMD[dtype, frag_size], dst: LayoutTensor[c_type, c_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], warp_offset: UInt32 = 0)

Write a fragment to shared memory using st.matrix.

Args:

  • frag (SIMD): Source fragment (typically loaded from TMEM).
  • dst (LayoutTensor): Destination shared memory tile.
  • warp_offset (UInt32): Additional warp-based offset for transpose mode.
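To see where the values in `frag` end up, recall how st.matrix distributes data. In the PTX fragment layout for 8x8 matrices of 16-bit elements, lane `l` holds one 32-bit register per matrix. That register packs two adjacent columns of row `l // 4`, starting at column `2 * (l % 4)`. The Python model below encodes that mapping and checks that the 32 lanes cover each matrix exactly once. How write_fragment packs `frag_size` values into such registers is not documented here. The function name `stmatrix_element_coords` is hypothetical.

```python
def stmatrix_element_coords(lane: int, reg: int, half: int) -> tuple:
    """(matrix, row, col) of element `half` (0 or 1) in 32-bit register `reg` of `lane`.

    With .x4, register `reg` (0..3) carries the lane's piece of matrix `reg`;
    the register's two 16-bit elements are adjacent columns of one row.
    """
    row = lane // 4
    col = 2 * (lane % 4) + half
    return reg, row, col


# Every (row, col) of an 8x8 matrix is written by exactly one (lane, half) pair.
covered = [stmatrix_element_coords(lane, 0, half)[1:] for lane in range(32) for half in range(2)]
assert sorted(covered) == [(r, c) for r in range(8) for c in range(8)]
print(stmatrix_element_coords(5, 2, 1))
```

Because each lane only contributes two consecutive 16-bit values per matrix row, the hardware assembles full 16-byte rows from four lanes. This is why each row needs only one shared-memory address, supplied by the lane that owns that row.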
