Mojo struct
FragmentToSMemWriter
@register_passable(trivial)
struct FragmentToSMemWriter[c_type: DType, c_tile_layout: Layout, //, tile_n_size: Int, num_m_mmas: Int, num_consumer: Int, half_tile: Bool, WG_BM: Int, WG_BN: Int, sub_wg_id: Int]
Writes WGMMA accumulator results from registers to shared memory using st.matrix.
Stores 16-byte fragments through a swizzled layout to avoid shared-memory bank conflicts. Sub-warp groups divide the N-dimension work: each instance handles the portion of the WG_BN-wide output tiles selected by sub_wg_id.
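A simplified model of why swizzling 16-byte fragments avoids bank conflicts may help. Python is used here only to illustrate the arithmetic; this is not the Mojo API. The sketch assumes 32 four-byte banks, a 128-byte tile row pitch (64 bf16 columns), and a generic XOR swizzle `chunk ^ (row % 8)`. The exact permutation that make_ldmatrix_swizzle produces may differ.

```python
# Simplified shared-memory model (assumption): 32 banks, 4 bytes each.
NUM_BANKS = 32
BANK_WIDTH = 4      # bytes per bank
CHUNK_BYTES = 16    # one st.matrix row: 8 bf16 values
ROW_BYTES = 128     # assumed tile row pitch (64 bf16 columns)

def chunk_address(row, chunk, swizzle):
    """Byte address of a 16-byte chunk in a row-major tile."""
    if swizzle:
        # Generic XOR swizzle (assumption): permute the chunk index with the low 3 row bits.
        chunk ^= row % 8
    return row * ROW_BYTES + chunk * CHUNK_BYTES

def banks_touched(addr):
    """Set of banks covered by a 16-byte access starting at addr."""
    first = addr // BANK_WIDTH
    return {(first + i) % NUM_BANKS for i in range(CHUNK_BYTES // BANK_WIDTH)}

def max_bank_conflict(swizzle):
    """Eight lanes store column chunk 0 of rows 0..7 (one 8x8 matrix); return the worst bank load."""
    counts = [0] * NUM_BANKS
    for row in range(8):
        for bank in banks_touched(chunk_address(row, 0, swizzle)):
            counts[bank] += 1
    return max(counts)

print(max_bank_conflict(swizzle=False))  # 8: every row hits banks 0-3
print(max_bank_conflict(swizzle=True))   # 1: each row lands on its own 4 banks
```

Without swizzling, every row's first chunk starts at a multiple of 128 bytes and maps to the same four banks. XOR-ing the chunk index with the row spreads the eight rows across all 32 banks.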
Parameters
- c_type (DType): Output data type (must be bfloat16 for st.matrix).
- c_tile_layout (Layout): Layout of the entire shared memory region.
- tile_n_size (Int): Width of each output tile (typically TMA_BN).
- num_m_mmas (Int): Number of MMA operations in the M dimension.
- num_consumer (Int): Number of consumer warp groups.
- half_tile (Bool): Special mode for handling partial tiles.
- WG_BM (Int): Warp group tile height.
- WG_BN (Int): Warp group tile width.
- sub_wg_id (Int): Which portion of WG_BN this instance handles.
Fields
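- c_tile (LayoutTensor[c_type, c_tile_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128]): Shared memory tile to write to.
- warp_group_thread_idx (UInt): Thread index within the warp group.
- local_warp_group_idx (UInt): Sub-warp group index (divides N-dimension work).
- st_matrix_rt_layout (RuntimeLayout[st_matrix_n_layout[c_type, tile_n_size, num_m_mmas, num_consumer](), element_type=DType.int32, linear_idx_type=DType.int32]): Runtime layout instance of type st_matrix_rt_layout_type.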
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
Movable,
RegTileWriter,
UnknownDestructibility
Aliases
__copyinit__is_trivial
alias __copyinit__is_trivial = True
__del__is_trivial
alias __del__is_trivial = True
__moveinit__is_trivial
alias __moveinit__is_trivial = True
st_matrix_layout
alias st_matrix_layout = Layout.row_major(WG_BM, tile_n_size)
st_matrix_rt_layout_type
alias st_matrix_rt_layout_type = RuntimeLayout[st_matrix_n_layout[c_type, tile_n_size, num_m_mmas, num_consumer](), element_type=DType.int32, linear_idx_type=DType.int32]
st_matrix_swizzle
alias st_matrix_swizzle = make_ldmatrix_swizzle[c_type, tile_n_size, log2_floor((16 // size_of[c_type]()))]()
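The compile-time arithmetic behind two of these aliases can be checked by hand. The Python sketch below mirrors `log2_floor(16 // size_of[c_type]())` from st_matrix_swizzle and the offset rule of `Layout.row_major(WG_BM, tile_n_size)` (strides `(tile_n_size, 1)`). The DType byte sizes are standard; the WG_BM = 128 and tile_n_size = 64 values are hypothetical examples. How make_ldmatrix_swizzle uses its third argument is not described on this page. Presumably it is the number of column-index bits covered by one 16-byte vector.

```python
# Byte sizes of two Mojo DTypes (what size_of[c_type]() returns).
SIZE_OF = {"bfloat16": 2, "float32": 4}

def log2_floor(n):
    """Floor of log2 for a positive integer."""
    return n.bit_length() - 1

def swizzle_arg(dtype_name):
    """Third argument st_matrix_swizzle passes to make_ldmatrix_swizzle."""
    elems_per_16_bytes = 16 // SIZE_OF[dtype_name]
    return log2_floor(elems_per_16_bytes)

def row_major_offset(row, col, wg_bm, tile_n_size):
    """Element offset in Layout.row_major(wg_bm, tile_n_size): row * tile_n_size + col."""
    assert 0 <= row < wg_bm and 0 <= col < tile_n_size
    return row * tile_n_size + col

print(swizzle_arg("bfloat16"))          # 3: 8 bf16 values per 16 bytes
print(swizzle_arg("float32"))           # 2: 4 float32 values per 16 bytes
print(row_major_offset(2, 5, 128, 64))  # 133 (hypothetical WG_BM=128, tile_n_size=64)
```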
Methods
__init__
__init__(c_tile: LayoutTensor[c_type, c_tile_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], warp_group_thread_idx: UInt, local_warp_group_idx: UInt) -> Self
Initialize the fragment writer.
Args:
- c_tile (LayoutTensor): Shared memory tile to write to.
- warp_group_thread_idx (UInt): Thread index within the warp group.
- local_warp_group_idx (UInt): Sub-warp group index (divides N-dimension work).
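For intuition about warp_group_thread_idx, the Python sketch below splits a warp-group thread index (0 to 127) into the roles a lane typically plays. It relies on two standard hardware conventions. First, in a 64-row WGMMA accumulator, warp w holds rows 16w to 16w + 15. Second, in a PTX stmatrix .x4 store, lanes 0-7 supply the row addresses of matrix 0, lanes 8-15 those of matrix 1, and so on. How FragmentToSMemWriter combines these with local_warp_group_idx and st_matrix_rt_layout is not documented here, so treat this as an illustration, not the struct's implementation.

```python
WARP_SIZE = 32
WARPS_PER_WARP_GROUP = 4
ROWS_PER_WARP = 16  # warp w holds rows 16*w .. 16*w + 15 of a 64-row WGMMA accumulator

def st_matrix_roles(warp_group_thread_idx):
    """Split a warp-group thread index into its roles in an st.matrix .x4 store."""
    assert 0 <= warp_group_thread_idx < WARP_SIZE * WARPS_PER_WARP_GROUP
    warp = warp_group_thread_idx // WARP_SIZE
    lane = warp_group_thread_idx % WARP_SIZE
    matrix = lane // 8        # lanes 0-7 address matrix 0, 8-15 matrix 1, and so on
    row_in_matrix = lane % 8  # row of that 8x8 matrix whose address this lane supplies
    first_accum_row = warp * ROWS_PER_WARP
    return {
        "warp": warp,
        "lane": lane,
        "matrix": matrix,
        "row_in_matrix": row_in_matrix,
        "first_accum_row": first_accum_row,
    }

print(st_matrix_roles(77))
# {'warp': 2, 'lane': 13, 'matrix': 1, 'row_in_matrix': 5, 'first_accum_row': 32}
```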
write_tile
write_tile(self, c_reg_tile: LayoutTensor[_dtype, layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], coords: Tuple[UInt, UInt])
Write accumulator tile from registers to shared memory.
Args:
- c_reg_tile (LayoutTensor): Register tile containing MMA results.
- coords (Tuple[UInt, UInt]): Tile position (row_idx, col_idx) in the output.
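WGMMA with bf16 inputs accumulates in float32. The values in c_reg_tile are therefore presumably float32 and have to be narrowed to bfloat16 before st.matrix can store them. This page does not say where that conversion happens. The Python sketch below shows round-to-nearest-even narrowing and packs eight values into 16 bytes. That is the size of one st.matrix row and of the four 32-bit registers each lane supplies to an .x4 store. It illustrates the arithmetic only and is not write_tile's code.

```python
import struct

def f32_to_bf16_bits(x):
    """Round a float32 value to bfloat16 (round-to-nearest-even); return the 16 raw bits."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    if (bits & 0x7F800000) == 0x7F800000 and (bits & 0x007FFFFF):
        return (bits >> 16) | 0x0040  # NaN: truncate and keep it a quiet NaN
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)  # ties go to the even bf16 value
    return ((bits + rounding_bias) >> 16) & 0xFFFF

def pack_fragment(values):
    """Pack 8 float32 accumulator values into one 16-byte bf16 fragment (little-endian)."""
    assert len(values) == 8
    return struct.pack("<8H", *(f32_to_bf16_bits(v) for v in values))

frag = pack_fragment([1.0, -2.0, 0.5, 0.0, 3.0, 1.5, -0.25, 2.0])
print(len(frag))   # 16
print(frag.hex())  # 803f00c0003f00004040c03f80be0040
```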