Mojo struct
StMatrixWriter
@register_passable(trivial)
struct StMatrixWriter[c_type: DType, c_smem_layout: Layout, stageN: Int, c_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, transpose_c: Bool = False]
Write register fragments to shared memory using st.matrix.
Handles the complex swizzling and addressing required for efficient shared memory writes from WGMMA accumulator fragments.
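Template Parameters:
- c_type (DType): Output data type.
- c_smem_layout (Layout): Shared memory tile layout.
- stageN (Int): Stage width in elements.
- c_swizzle (TensorMapSwizzle): TMA swizzle mode. Defaults to TensorMapSwizzle.SWIZZLE_128B.
- transpose_c (Bool): Whether output is transposed. Defaults to False.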
Fields
- swizzle (Swizzle)
- lane_id (UInt32)
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
Movable,
UnknownDestructibility
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
Config
comptime Config = StMatrixConfig[c_type, stageN, c_swizzle, transpose_c]
shape0
comptime shape0 = c_smem_layout.shape[1].value() if (not transpose_c._mlir_value) else c_smem_layout.shape[0].value()
stride0
comptime stride0 = c_smem_layout.stride[0].value()
stride1
comptime stride1 = c_smem_layout.stride[1].value()
stsmx_tile_offset
comptime stsmx_tile_offset = (StMatrixWriter[c_type, c_smem_layout, stageN, c_swizzle, transpose_c].stride0 if transpose_c else StMatrixWriter[c_type, c_smem_layout, stageN, c_swizzle, transpose_c].stride1 * StMatrixWriter[c_type, c_smem_layout, stageN, c_swizzle, transpose_c].Config.stsmx_row_size)
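The conditional expressions in shape0 and stsmx_tile_offset can be easy to misread, so here is a minimal sketch of how they resolve. It is not the library's implementation: the helper names select_shape0 and tile_offset are hypothetical, plain run-time Int arguments stand in for the comptime values taken from c_smem_layout and Config.stsmx_row_size, and the tile dimensions, strides, and row size in main are made-up numbers chosen only for illustration.

```mojo
# Illustrative sketch only: plain run-time Ints stand in for the comptime
# values that StMatrixWriter derives from c_smem_layout and Config.


fn select_shape0(dim0: Int, dim1: Int, transpose_c: Bool) -> Int:
    # Mirrors shape0: layout dimension 1 normally, dimension 0 when transposed.
    return dim0 if transpose_c else dim1


fn tile_offset(
    stride0: Int, stride1: Int, row_size: Int, transpose_c: Bool
) -> Int:
    # Mirrors stsmx_tile_offset. Multiplication binds tighter than the
    # conditional expression, so only the non-transposed branch is scaled
    # by the row size.
    return stride0 if transpose_c else stride1 * row_size


fn main():
    # Hypothetical 64 x 32 row-major tile (strides 32 and 1), row size 8.
    print(select_shape0(64, 32, False))
    print(tile_offset(32, 1, 8, False))
    print(tile_offset(32, 1, 8, True))
```

With these assumed values, the non-transposed offset is stride1 * row_size = 1 * 8, while the transposed offset is stride0 = 32 with no scaling applied.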
Methods
__init__
__init__(lane_id: UInt32) -> Self
Initialize the st.matrix writer.
Args:
- lane_id (UInt32): Lane ID within the warp.
compute_lane_offset
compute_lane_offset(self) -> UInt32
Compute the base lane offset for st.matrix.
Returns:
UInt32: Lane offset in shared memory.
write_fragment
write_fragment[frag_size: Int](self, frag: SIMD[dtype, frag_size], dst: LayoutTensor[c_type, c_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], warp_offset: UInt32 = 0)
Write a fragment to shared memory using st.matrix.
Args:
- frag (SIMD): Source fragment (typically from TMEM load).
- dst (LayoutTensor): Destination shared memory tile.
- warp_offset (UInt32): Additional warp-based offset for transpose mode.