Mojo struct

FragmentToSMemWriter

@register_passable(trivial) struct FragmentToSMemWriter[c_type: DType, c_tile_layout: Layout, //, tile_n_size: Int, num_m_mmas: Int, num_consumer: Int, half_tile: Bool, WG_BM: Int, WG_BN: Int, sub_wg_id: Int]

Writes WGMMA accumulator results from registers to shared memory using st.matrix.

Stores 16-byte fragments with swizzling to avoid shared memory bank conflicts. Sub-warp groups divide the work along the N dimension, each handling a portion of the WG_BN-wide output.

Parameters

  • c_type (DType): Output data type (must be bfloat16 for st.matrix).
  • c_tile_layout (Layout): Layout of the entire shared memory region.
  • tile_n_size (Int): Width of each output tile (typically TMA_BN).
  • num_m_mmas (Int): Number of MMA operations in the M dimension.
  • num_consumer (Int): Number of consumer warp groups.
  • half_tile (Bool): Special mode for handling partial tiles.
  • WG_BM (Int): Warp group tile height.
  • WG_BN (Int): Warp group tile width.
  • sub_wg_id (Int): Which portion of WG_BN this instance handles.

Fields

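  • c_tile (LayoutTensor[c_type, c_tile_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128]): Shared memory tile to write to.
  • warp_group_thread_idx (UInt): Thread index within the warp group.
  • local_warp_group_idx (UInt): Sub-warp group index (divides N-dimension work).
  • st_matrix_rt_layout (RuntimeLayout[st_matrix_n_layout[c_type, tile_n_size, num_m_mmas, num_consumer](), element_type=DType.int32, linear_idx_type=DType.int32]): Runtime layout built from st_matrix_n_layout; its type is the st_matrix_rt_layout_type alias.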

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, Movable, RegTileWriter, UnknownDestructibility

Aliases

__copyinit__is_trivial

alias __copyinit__is_trivial = True

__del__is_trivial

alias __del__is_trivial = True

__moveinit__is_trivial

alias __moveinit__is_trivial = True

st_matrix_layout

alias st_matrix_layout = Layout.row_major(WG_BM, tile_n_size)
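
A row-major layout of shape (WG_BM, tile_n_size) places element (row, col) at linear offset row * tile_n_size + col. The sketch below illustrates that indexing in plain Python rather than Mojo. The function name row_major_offset and the example sizes (WG_BM = 64, tile_n_size = 64) are hypothetical, chosen only for illustration, and the sketch ignores the swizzle that st_matrix_swizzle applies on top of this layout.

```python
def row_major_offset(row: int, col: int, num_rows: int, num_cols: int) -> int:
    """Linear element offset of (row, col) in a row-major (num_rows, num_cols) layout."""
    if not (0 <= row < num_rows and 0 <= col < num_cols):
        raise IndexError("coordinate outside the layout")
    # Rows are contiguous: stepping one row skips num_cols elements.
    return row * num_cols + col


# Hypothetical sizes, for illustration only.
WG_BM = 64
TILE_N_SIZE = 64

print(row_major_offset(0, 5, WG_BM, TILE_N_SIZE))  # 5
print(row_major_offset(2, 0, WG_BM, TILE_N_SIZE))  # 128
```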

st_matrix_rt_layout_type

alias st_matrix_rt_layout_type = RuntimeLayout[st_matrix_n_layout[c_type, tile_n_size, num_m_mmas, num_consumer](), element_type=DType.int32, linear_idx_type=DType.int32]

st_matrix_swizzle

alias st_matrix_swizzle = make_ldmatrix_swizzle[c_type, tile_n_size, log2_floor((16 // size_of[c_type]()))]()
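
The third argument, log2_floor(16 // size_of[c_type]()), is the base-2 logarithm of the number of elements that fit in one 16-byte fragment. For bfloat16 (2 bytes per element) this is log2_floor(8) = 3. The following plain-Python sketch works through that arithmetic. The helper names are hypothetical stand-ins, and the sketch does not reproduce make_ldmatrix_swizzle itself.

```python
FRAGMENT_BYTES = 16  # fragment size taken from the alias expression


def log2_floor(n: int) -> int:
    """Floor of log2(n) for a positive integer (stand-in for the Mojo helper)."""
    if n <= 0:
        raise ValueError("n must be positive")
    return n.bit_length() - 1


def swizzle_log2_arg(element_bytes: int) -> int:
    """log2 of the element count in one 16-byte fragment."""
    return log2_floor(FRAGMENT_BYTES // element_bytes)


# bfloat16 is 2 bytes wide: 16 // 2 = 8 elements per fragment.
print(swizzle_log2_arg(2))  # 3
```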

Methods

__init__

__init__(c_tile: LayoutTensor[c_type, c_tile_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], warp_group_thread_idx: UInt, local_warp_group_idx: UInt) -> Self

Initialize the fragment writer.

Args:

  • c_tile (LayoutTensor): Shared memory tile to write to.
  • warp_group_thread_idx (UInt): Thread index within the warp group.
  • local_warp_group_idx (UInt): Sub-warp group index (divides N-dimension work).

write_tile

write_tile(self, c_reg_tile: LayoutTensor[_dtype, layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], coords: Tuple[UInt, UInt])

Write accumulator tile from registers to shared memory.

Args:

  • c_reg_tile (LayoutTensor): Register tile containing MMA results.
  • coords (Tuple): Tile position (row_idx, col_idx) in output.
