
Mojo struct

FragmentToSMemWriter

struct FragmentToSMemWriter[c_type: DType, c_tile_layout: TensorLayout, //, tile_n_size: Int, num_m_mmas: Int, num_consumer: Int, half_tile: Bool, WG_BM: Int, WG_BN: Int, sub_wg_id: Int, swapAB: Bool = False]

Writes WGMMA accumulator results from registers to shared memory using st.matrix.

Stores 16-byte fragments with swizzling to avoid shared memory bank conflicts. Sub-warp groups divide the N-dimension work, each handling a portion of the WG_BN columns of the warp-group tile.
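To show why swizzling matters, the following Python model assumes 32 four-byte shared memory banks (so a 16-byte fragment falls into one of eight 4-bank groups), 128-byte tile rows (64 bfloat16 values), and a generic XOR swizzle `chunk ^ (row % 8)`. It illustrates the technique only; it is not the exact permutation built by `make_ldmatrix_swizzle`.

```python
# Illustrative model only: a generic XOR swizzle, not make_ldmatrix_swizzle itself.
BANK_BYTES = 4        # width of one shared memory bank
NUM_BANKS = 32        # banks per 128-byte shared memory line
FRAGMENT_BYTES = 16   # one st.matrix row fragment (8 x bfloat16)
ROW_BYTES = 128       # assumed tile row width: 64 bfloat16 elements
GROUPS = NUM_BANKS * BANK_BYTES // FRAGMENT_BYTES  # 8 groups of 4 banks
CHUNKS_PER_ROW = ROW_BYTES // FRAGMENT_BYTES       # 8 fragments per row


def fragment_offset(row: int, chunk: int, swizzle: bool) -> int:
    """Byte offset of the 16-byte fragment at (row, chunk) in shared memory."""
    physical = chunk ^ (row % CHUNKS_PER_ROW) if swizzle else chunk
    return row * ROW_BYTES + physical * FRAGMENT_BYTES


def bank_groups(chunk: int, swizzle: bool, rows: int = 8) -> list[int]:
    """Bank group hit by each of `rows` consecutive rows storing the same chunk."""
    return [(fragment_offset(r, chunk, swizzle) // FRAGMENT_BYTES) % GROUPS
            for r in range(rows)]


if __name__ == "__main__":
    print(bank_groups(0, swizzle=False))  # every row hits group 0: 8-way conflict
    print(bank_groups(0, swizzle=True))   # rows spread over groups 0..7
```

Without the swizzle, eight rows writing the same logical 16-byte column all land on the same four banks; the XOR permutation keeps each row's data contiguous while rotating which bank group it occupies.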

Parameters

  • c_type (DType): Output data type (must be bfloat16 for st.matrix).
  • c_tile_layout (TensorLayout): Layout of the entire shared memory region.
  • tile_n_size (Int): Width of each output tile (typically TMA_BN).
  • num_m_mmas (Int): Number of MMA operations in the M dimension.
  • num_consumer (Int): Number of consumer warp groups.
  • half_tile (Bool): Special mode for handling partial tiles.
  • WG_BM (Int): Warp group tile height.
  • WG_BN (Int): Warp group tile width.
  • sub_wg_id (Int): Which portion of WG_BN this instance handles.
  • swapAB (Bool): Whether to swap the A and B matrices. When True, the writer uses the transposed st.matrix layout and a tile_n_size x WG_BN shared memory tile. Defaults to False.
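A minimal sketch of how `sub_wg_id` could map to columns, assuming WG_BN is split into equal slices of `tile_n_size` columns and slice `sub_wg_id` starts at column `sub_wg_id * tile_n_size`. The helper `column_slice` is hypothetical; the real partitioning, including how `half_tile` affects it, is defined by the Mojo implementation.

```python
def column_slice(sub_wg_id: int, tile_n_size: int, wg_bn: int) -> range:
    """Hypothetical: columns of the WG_BN-wide tile handled by one sub-warp group."""
    # Assumption: WG_BN divides evenly into tile_n_size-wide slices.
    if wg_bn % tile_n_size != 0:
        raise ValueError("assumed: WG_BN is a multiple of tile_n_size")
    num_slices = wg_bn // tile_n_size
    if not 0 <= sub_wg_id < num_slices:
        raise ValueError(f"sub_wg_id must be in [0, {num_slices})")
    start = sub_wg_id * tile_n_size
    return range(start, start + tile_n_size)


if __name__ == "__main__":
    # With WG_BN = 256 and tile_n_size = 64 there are four slices.
    for sub_wg_id in range(4):
        cols = column_slice(sub_wg_id, tile_n_size=64, wg_bn=256)
        print(sub_wg_id, cols.start, cols.stop)
```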

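Fields

  • c_tile (TileTensor[c_type, c_tile_layout, MutAnyOrigin, address_space=AddressSpace.SHARED]): Shared memory tile that receives the output fragments.
  • warp_group_thread_idx (Int): Index of the calling thread within its warp group.
  • local_warp_group_idx (Int): Index of the consumer warp group that owns this writer.
  • st_matrix_rt_layout (FragmentToSMemWriter[tile_n_size, num_m_mmas, num_consumer, half_tile, WG_BM, WG_BN, sub_wg_id, swapAB].st_matrix_rt_layout_type): Runtime layout used to compute st.matrix store offsets.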

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegTileWriter, RegisterPassable, TrivialRegisterPassable

comptime members

st_matrix_layout_regular

comptime st_matrix_layout_regular = st_matrix_n_layout[c_type, tile_n_size, num_m_mmas, num_consumer]()

st_matrix_layout_transpose

comptime st_matrix_layout_transpose = st_matrix_m_layout[c_type, tile_n_size, num_m_mmas, num_consumer]()

st_matrix_rt_layout_type

comptime st_matrix_rt_layout_type = RuntimeLayout[FragmentToSMemWriter[tile_n_size, num_m_mmas, num_consumer, half_tile, WG_BM, WG_BN, sub_wg_id, swapAB].st_matrix_layout_regular if not swapAB else FragmentToSMemWriter[tile_n_size, num_m_mmas, num_consumer, half_tile, WG_BM, WG_BN, sub_wg_id, swapAB].st_matrix_layout_transpose, element_type=DType.int32, linear_idx_type=DType.int32]

st_matrix_swizzle

comptime st_matrix_swizzle = make_ldmatrix_swizzle[c_type, tile_n_size if not swapAB else WG_BN, log2_floor((Int(16) // size_of[c_type]()))]()
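The third argument to `make_ldmatrix_swizzle` is `log2_floor(16 // size_of[c_type]())`, the base-2 logarithm of how many elements fit in one 16-byte fragment. A short Python model of that arithmetic, using standard byte widths and a local `log2_floor` assumed to match the Mojo helper's floor-of-log2 meaning (nothing here is imported from Mojo):

```python
SIZE_OF = {"bfloat16": 2, "float32": 4}  # bytes per element


def log2_floor(n: int) -> int:
    """Floor of log2(n) for n >= 1."""
    if n < 1:
        raise ValueError("n must be positive")
    return n.bit_length() - 1


def swizzle_log2_elems(dtype: str) -> int:
    """log2 of the number of `dtype` elements in one 16-byte fragment."""
    return log2_floor(16 // SIZE_OF[dtype])


if __name__ == "__main__":
    print(swizzle_log2_elems("bfloat16"))  # 3: a 16-byte fragment holds 8 values
```

For the required bfloat16 output this always evaluates to 3, so the swizzle operates on 8-element (16-byte) units.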

st_matrix_tile_layout_regular

comptime st_matrix_tile_layout_regular = row_major[WG_BM, tile_n_size]()

st_matrix_tile_layout_swapAB

comptime st_matrix_tile_layout_swapAB = row_major[tile_n_size, WG_BN]()
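These members pick the shared memory tile shape from `swapAB`: `row_major[WG_BM, tile_n_size]` normally and `row_major[tile_n_size, WG_BN]` when swapped. In both cases the width passed to `make_ldmatrix_swizzle` (`tile_n_size if not swapAB else WG_BN`) equals the tile's row length. A Python model of that selection with row-major offset arithmetic; `RowMajor`, `smem_tile_layout`, and `swizzle_width` are hypothetical names, not library APIs:

```python
from typing import NamedTuple


class RowMajor(NamedTuple):
    rows: int
    cols: int

    def offset(self, r: int, c: int) -> int:
        """Linear element offset of (r, c) in a row-major tile."""
        assert 0 <= r < self.rows and 0 <= c < self.cols
        return r * self.cols + c


def smem_tile_layout(wg_bm: int, wg_bn: int, tile_n_size: int, swap_ab: bool) -> RowMajor:
    """Mirror of the regular / swapAB tile layout choice."""
    return RowMajor(tile_n_size, wg_bn) if swap_ab else RowMajor(wg_bm, tile_n_size)


def swizzle_width(wg_bn: int, tile_n_size: int, swap_ab: bool) -> int:
    """Mirror of `tile_n_size if not swapAB else WG_BN`."""
    return wg_bn if swap_ab else tile_n_size


if __name__ == "__main__":
    for swap_ab in (False, True):
        layout = smem_tile_layout(128, 256, 64, swap_ab)
        print(swap_ab, layout, swizzle_width(256, 64, swap_ab))
```

Keying the swizzle width to the row length keeps the XOR permutation confined to a single tile row in either orientation.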

Methods

__init__

def __init__(c_tile: TileTensor[c_type, c_tile_layout, MutAnyOrigin, address_space=AddressSpace.SHARED], warp_group_thread_idx: Int, local_warp_group_idx: Int) -> Self

Initialize the fragment writer.

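Args:

  • c_tile (TileTensor): Shared memory tile to write the output into.
  • warp_group_thread_idx (Int): Index of the calling thread within its warp group.
  • local_warp_group_idx (Int): Index of the consumer warp group.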
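On Hopper a warp group is four consecutive warps (128 threads), so `warp_group_thread_idx` is expected to lie in [0, 128). A minimal Python sketch of splitting it into a warp index and lane, assuming the WGMMA accumulator arrangement described in the PTX ISA, in which warp `w` of the group owns rows `16*w` through `16*w + 15` of each 64-row MMA tile. The helper names are hypothetical.

```python
WARP_SIZE = 32
WARPS_PER_GROUP = 4
WARP_GROUP_SIZE = WARP_SIZE * WARPS_PER_GROUP  # 128 threads


def split_thread_idx(warp_group_thread_idx: int) -> tuple[int, int]:
    """Return (warp index within the warp group, lane within the warp)."""
    if not 0 <= warp_group_thread_idx < WARP_GROUP_SIZE:
        raise ValueError("thread index must be in [0, 128)")
    return divmod(warp_group_thread_idx, WARP_SIZE)


def accumulator_rows(warp: int, mma_m: int = 64) -> range:
    """Rows of one 64-row WGMMA accumulator tile owned by `warp` (assumed layout)."""
    rows_per_warp = mma_m // WARPS_PER_GROUP  # 16
    return range(warp * rows_per_warp, (warp + 1) * rows_per_warp)


if __name__ == "__main__":
    warp, lane = split_thread_idx(37)
    print(warp, lane, accumulator_rows(warp))
```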

write_tile

def write_tile(self, c_reg_tile: LayoutTensor[MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=c_reg_tile.element_layout, layout_int_type=c_reg_tile.layout_int_type, linear_idx_type=c_reg_tile.linear_idx_type, masked=c_reg_tile.masked, alignment=c_reg_tile.alignment], coords: Tuple[Int, Int])

Write accumulator tile from registers to shared memory.

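Args:

  • c_reg_tile (LayoutTensor): Register tile holding the WGMMA accumulator fragments.
  • coords (Tuple[Int, Int]): Coordinates of the output tile to write.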
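The write uses the PTX `stmatrix` instruction. In its `.m8n8.x4.b16` form, each lane supplies the shared memory address of one 16-byte row (lanes 0 to 7 for matrix 0, lanes 8 to 15 for matrix 1, and so on) and, for each of the four 8x8 matrices, holds a 32-bit register with two adjacent elements of row `lane // 4`. The Python model below encodes that mapping as described in the PTX ISA; it is not taken from the Mojo implementation, and which `stmatrix` variant the writer actually emits is not stated here.

```python
def address_row(lane: int) -> tuple[int, int]:
    """(matrix, row) whose 16-byte row address this lane provides for stmatrix .x4."""
    return divmod(lane, 8)


def held_elements(lane: int) -> list[tuple[int, int]]:
    """(row, col) of the two b16 values this lane holds in each 8x8 matrix."""
    row, col = lane // 4, 2 * (lane % 4)
    return [(row, col), (row, col + 1)]


if __name__ == "__main__":
    for lane in (0, 5, 13, 31):
        print(lane, address_row(lane), held_elements(lane))
```

Together the 32 lanes hold all 64 elements of each 8x8 matrix exactly once, which is why a single warp-wide `stmatrix` can move a whole fragment without extra shuffles.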