Mojo struct
FragmentToSMemWriter
struct FragmentToSMemWriter[c_type: DType, c_tile_layout: Layout, //, tile_n_size: Int, num_m_mmas: Int, num_consumer: Int, half_tile: Bool, WG_BM: Int, WG_BN: Int, sub_wg_id: Int, swapAB: Bool = False]
Writes WGMMA accumulator results from registers to shared memory using st.matrix.
Stores 16-byte fragments with swizzling to avoid shared-memory bank conflicts. Sub-warp groups divide the N-dimension work; each one handles a portion of the output tiles that span WG_BN.
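The bank-conflict argument can be illustrated with a small model. The sketch below assumes 32 shared-memory banks of 4 bytes each and tile rows that are exactly 128 bytes wide (64 bfloat16 values). It then counts how many distinct 4-bank groups eight rows touch when each row writes the same logical 16-byte chunk. The swizzle is a generic XOR pattern (`chunk ^ (row % 8)`) chosen for illustration. It is an assumption, not the exact mapping produced by `make_ldmatrix_swizzle`, and `bank_of_chunk` and `bank_groups_hit` are hypothetical helpers.

```mojo
# Generic XOR-swizzle illustration. Assumptions: 32 banks x 4 bytes,
# 128-byte tile rows (64 bfloat16 values), 16-byte chunks. The XOR
# pattern is hypothetical, not the exact make_ldmatrix_swizzle mapping.

fn bank_of_chunk(row: Int, chunk: Int, swizzle: Bool) -> Int:
    # Physical 16-byte slot within the row (8 slots per 128-byte row).
    var slot = (chunk ^ (row % 8)) if swizzle else chunk
    var byte_offset = row * 128 + slot * 16
    # First of the 4 consecutive banks this 16-byte chunk occupies.
    return (byte_offset // 4) % 32


fn bank_groups_hit(chunk: Int, swizzle: Bool) -> Int:
    # Rows 0..7 each write logical chunk `chunk`; count distinct 4-bank groups.
    var mask = 0
    for row in range(8):
        var group = bank_of_chunk(row, chunk, swizzle) // 4
        mask = mask | (1 << group)
    var count = 0
    for g in range(8):
        if ((mask >> g) & 1) == 1:
            count += 1
    return count


fn main():
    for chunk in range(8):
        print(
            "chunk", chunk,
            "groups without swizzle:", bank_groups_hit(chunk, False),
            "with swizzle:", bank_groups_hit(chunk, True)
        )
```

Without swizzling, every chunk reports 1 group: all eight rows hit the same four banks, an 8-way conflict. With the XOR swizzle, every chunk reports 8 groups, so each row lands on a different set of banks.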
Parameters
- c_type (DType): Output data type (must be bfloat16 for st.matrix).
- c_tile_layout (Layout): Layout of the entire shared memory region.
- tile_n_size (Int): Width of each output tile (typically TMA_BN).
- num_m_mmas (Int): Number of MMA operations in the M dimension.
- num_consumer (Int): Number of consumer warp groups.
- half_tile (Bool): Special mode for handling partial tiles.
- WG_BM (Int): Warp group tile height.
- WG_BN (Int): Warp group tile width.
- sub_wg_id (Int): Which portion of WG_BN this instance handles.
- swapAB (Bool): Whether to swap the A and B matrices.
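The parameter list does not spell out how `sub_wg_id` splits `WG_BN`. As one plausible reading, the sketch below assumes a contiguous split: `WG_BN` holds `WG_BN // tile_n_size` output tiles, and each of a hypothetical `num_sub_wg` sub-warp groups takes an equal, consecutive share. Both the contiguous split and `num_sub_wg` are assumptions for illustration.

```mojo
# Hypothetical contiguous split of a WG_BN-wide tile among sub-warp groups.
# `num_sub_wg` and the split policy are assumptions for illustration.

fn tiles_per_sub_wg(wg_bn: Int, tile_n_size: Int, num_sub_wg: Int) -> Int:
    # tile_n_size-wide output tiles in WG_BN, shared equally.
    return (wg_bn // tile_n_size) // num_sub_wg


fn first_column(sub_wg_id: Int, wg_bn: Int, tile_n_size: Int, num_sub_wg: Int) -> Int:
    # First column (inside the warp-group tile) owned by this sub-warp group.
    return sub_wg_id * tiles_per_sub_wg(wg_bn, tile_n_size, num_sub_wg) * tile_n_size


fn main():
    var wg_bn = 256
    var tile_n_size = 64
    var num_sub_wg = 2
    var width = tiles_per_sub_wg(wg_bn, tile_n_size, num_sub_wg) * tile_n_size
    for sub_wg_id in range(num_sub_wg):
        var start = first_column(sub_wg_id, wg_bn, tile_n_size, num_sub_wg)
        print("sub_wg_id", sub_wg_id, "columns", start, "to", start + width - 1)
```

With `WG_BN = 256`, `tile_n_size = 64`, and two sub-warp groups, sub-warp group 0 covers columns 0 to 127 and sub-warp group 1 covers columns 128 to 255. The sketch assumes `WG_BN` is divisible by `tile_n_size * num_sub_wg`.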
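Fields
- c_tile (LayoutTensor[c_type, c_tile_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128]): Shared memory tile to write to.
- warp_group_thread_idx (Int): Thread index within the warp group.
- local_warp_group_idx (Int): Sub-warp group index (divides N-dimension work).
- st_matrix_rt_layout (FragmentToSMemWriter[tile_n_size, num_m_mmas, num_consumer, half_tile, WG_BM, WG_BN, sub_wg_id, swapAB].st_matrix_rt_layout_type): Runtime instance of the st.matrix layout (see st_matrix_rt_layout_type).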
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegTileWriter,
RegisterPassable,
TrivialRegisterPassable
comptime members
st_matrix_layout
comptime st_matrix_layout = Layout.row_major(WG_BM, tile_n_size) if not swapAB else Layout.row_major(tile_n_size, WG_BN)
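The `st_matrix_layout` selection can be mirrored with plain integer arithmetic. The sketch below reproduces the two shapes from the definition above and the row-major offset rule. The concrete sizes in `main` are illustrative values, not defaults of the struct, and the helper names are hypothetical.

```mojo
# Mirrors the st_matrix_layout choice: row_major(WG_BM, tile_n_size)
# normally, row_major(tile_n_size, WG_BN) when swapAB is set.
# Sizes in main() are illustrative only.

fn st_matrix_rows(wg_bm: Int, tile_n_size: Int, swap_ab: Bool) -> Int:
    return tile_n_size if swap_ab else wg_bm


fn st_matrix_cols(wg_bn: Int, tile_n_size: Int, swap_ab: Bool) -> Int:
    return wg_bn if swap_ab else tile_n_size


fn row_major_offset(row: Int, col: Int, num_cols: Int) -> Int:
    # Row-major: consecutive columns of a row are adjacent in memory.
    return row * num_cols + col


fn main():
    var wg_bm = 64
    var wg_bn = 256
    var tile_n_size = 32
    for swap in range(2):
        var swap_ab = swap == 1
        var rows = st_matrix_rows(wg_bm, tile_n_size, swap_ab)
        var cols = st_matrix_cols(wg_bn, tile_n_size, swap_ab)
        print("swapAB =", swap_ab, "shape", rows, "x", cols,
              "offset of (1, 0) =", row_major_offset(1, 0, cols))
```

With these sizes the regular layout is 64 x 32 and the swapped layout is 32 x 256, so the element at (1, 0) sits at offset 32 and 256 respectively.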
st_matrix_layout_regular
comptime st_matrix_layout_regular = st_matrix_n_layout[c_type, tile_n_size, num_m_mmas, num_consumer]()
st_matrix_layout_transpose
comptime st_matrix_layout_transpose = st_matrix_m_layout[c_type, tile_n_size, num_m_mmas, num_consumer]()
st_matrix_rt_layout_type
comptime st_matrix_rt_layout_type = RuntimeLayout[FragmentToSMemWriter[tile_n_size, num_m_mmas, num_consumer, half_tile, WG_BM, WG_BN, sub_wg_id, swapAB].st_matrix_layout_regular if not swapAB else FragmentToSMemWriter[tile_n_size, num_m_mmas, num_consumer, half_tile, WG_BM, WG_BN, sub_wg_id, swapAB].st_matrix_layout_transpose, element_type=DType.int32, linear_idx_type=DType.int32]
st_matrix_swizzle
comptime st_matrix_swizzle = make_ldmatrix_swizzle[c_type, tile_n_size if not swapAB else WG_BN, log2_floor((16 // size_of[c_type]()))]()
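The third argument to `make_ldmatrix_swizzle` is `log2_floor(16 // size_of[c_type]())`. For `bfloat16` (2 bytes) this is `log2_floor(8) = 3`, because one 16-byte fragment holds 8 elements. The sketch below reproduces only this arithmetic with hypothetical helpers (`fragment_elements`, `log2_floor_int`). It does not model what `make_ldmatrix_swizzle` does with the value.

```mojo
# Reproduces log2_floor(16 // size_of[c_type]()) with plain Ints.
# Element sizes are passed in bytes; helper names are hypothetical.

fn log2_floor_int(value: Int) -> Int:
    # floor(log2(value)) for value >= 1, computed by repeated halving.
    var v = value
    var result = 0
    while v > 1:
        v = v >> 1
        result += 1
    return result


fn fragment_elements(elem_bytes: Int) -> Int:
    # Number of elements that fit in one 16-byte fragment.
    return 16 // elem_bytes


fn main():
    var bf16_bytes = 2  # bfloat16 is 2 bytes wide
    var n = fragment_elements(bf16_bytes)
    print("elements per fragment:", n, "log2 argument:", log2_floor_int(n))
```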
Methods
__init__
__init__(c_tile: LayoutTensor[c_type, c_tile_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], warp_group_thread_idx: Int, local_warp_group_idx: Int) -> Self
Initialize the fragment writer.
Args:
- c_tile (LayoutTensor[c_type, c_tile_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128]): Shared memory tile to write to.
- warp_group_thread_idx (Int): Thread index within the warp group.
- local_warp_group_idx (Int): Sub-warp group index (divides N-dimension work).
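A minimal sketch of how a thread's indices relate to its warp group. It assumes the Hopper convention that a warp group is four consecutive warps (128 threads) and that `warp_group_thread_idx` is the thread's position within its group. How a kernel derives `local_warp_group_idx` is not described here, so the sketch leaves it out. The helper names are hypothetical.

```mojo
# Assumption: a warp group is 4 consecutive warps (128 threads), and
# warp_group_thread_idx is the thread's position inside its group.

fn warp_group_thread_idx(thread_idx: Int) -> Int:
    return thread_idx % 128


fn warp_in_group(thread_idx: Int) -> Int:
    # Which of the 4 warps in the group this thread belongs to.
    return warp_group_thread_idx(thread_idx) // 32


fn lane(thread_idx: Int) -> Int:
    # Position within the 32-thread warp.
    return thread_idx % 32


fn main():
    for tid in range(0, 384, 97):
        print("thread", tid,
              "warp_group_thread_idx", warp_group_thread_idx(tid),
              "warp", warp_in_group(tid),
              "lane", lane(tid))
```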
write_tile
write_tile(self, c_reg_tile: LayoutTensor[MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=c_reg_tile.element_layout, layout_int_type=c_reg_tile.layout_int_type, linear_idx_type=c_reg_tile.linear_idx_type, masked=c_reg_tile.masked, alignment=c_reg_tile.alignment], coords: Tuple[Int, Int])
Write accumulator tile from registers to shared memory.
Args:
- c_reg_tile (LayoutTensor[MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=c_reg_tile.element_layout, layout_int_type=c_reg_tile.layout_int_type, linear_idx_type=c_reg_tile.linear_idx_type, masked=c_reg_tile.masked, alignment=c_reg_tile.alignment]): Register tile containing MMA results.
- coords (Tuple[Int, Int]): Tile position (row_idx, col_idx) in the output.
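To make the `coords` argument concrete, the sketch below converts a `(row_idx, col_idx)` tile position into the element offset of that tile's top-left corner in a row-major `M x N` output. The tile sizes, the row-major output, and the `tile_origin_offset` helper are assumptions for illustration. This page does not say how `write_tile` maps `coords` internally.

```mojo
# Hypothetical: maps tile coordinates to the element offset of the tile's
# top-left corner in a row-major output with n columns. Tile sizes and
# the row-major output are illustrative assumptions.

fn tile_origin_offset(row_idx: Int, col_idx: Int, tile_m: Int, tile_n: Int, n: Int) -> Int:
    var row0 = row_idx * tile_m  # first output row covered by the tile
    var col0 = col_idx * tile_n  # first output column covered by the tile
    return row0 * n + col0


fn main():
    var offset = tile_origin_offset(2, 3, 128, 64, 1024)
    print("tile (2, 3) starts at element offset", offset)
```

For coordinates (2, 3) with 128 x 64 tiles and N = 1024, the tile starts at row 256, column 192, which is element offset 262336.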