Mojo struct
FragmentToSMemWriter
@register_passable(trivial)
struct FragmentToSMemWriter[c_type: DType, c_tile_layout: Layout, //, tile_n_size: Int, num_m_mmas: Int, num_consumer: Int, use_x2_for_last_iter: Bool, WG_BM: Int, WG_BN: Int, sub_wg_bn_id: Int]
Writer for storing accumulator fragments from registers to shared memory.
Uses stmatrix instructions to store bf16 fragments efficiently, swizzling shared-memory addresses to avoid bank conflicts.
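To show why swizzling matters, here is a small Python model of a CuTe-style XOR swizzle, Swizzle<bits, base, shift>, applied to bf16 element offsets. It counts how many 16-byte row writes land on the same shared-memory bank. The 64-column tile width and the parameters (3, 3, 3) are assumptions for illustration. The actual st_matrix_swizzle passed to this writer may differ.

```python
# Python model of an XOR swizzle and its effect on shared-memory bank conflicts.
# Assumptions (not from the Mojo source): a row-major bf16 tile with 64 columns
# (128 bytes per row), and 8 rows written together at 16 bytes per row.
NUM_BANKS = 32      # shared memory banks
BANK_BYTES = 4      # bytes per bank word
BF16_BYTES = 2
ROW_ELEMS = 64      # hypothetical tile width in bf16 elements


def swizzle(offset: int, bits: int, base: int, shift: int) -> int:
    """CuTe-style Swizzle<bits, base, shift> applied to an element offset."""
    mask = ((1 << bits) - 1) << (base + shift)
    return offset ^ ((offset & mask) >> shift)


def max_bank_load(row_starts: list[int]) -> int:
    """Worst-case number of 16-byte segments that touch the same bank."""
    load = [0] * NUM_BANKS
    for elem in row_starts:
        start = elem * BF16_BYTES
        for byte in range(start, start + 16, BANK_BYTES):
            load[(byte // BANK_BYTES) % NUM_BANKS] += 1
    return max(load)


# Eight rows each write 8 bf16 values (16 bytes) starting at column 0.
plain = [row * ROW_ELEMS for row in range(8)]
swizzled = [swizzle(off, bits=3, base=3, shift=3) for off in plain]

print(max_bank_load(plain))     # 8: every row lands on banks 0-3
print(max_bank_load(swizzled))  # 1: the rows spread over all 32 banks
```

The swizzle XORs the 16-byte chunk index within a row with the low three bits of the row index. Each row therefore starts on a different group of four banks, and because the mapping is a bijection, no two elements share a slot.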
Parameters
- c_type (DType): Output data type (e.g., bfloat16).
- c_tile_layout (Layout): Shared memory tile layout.
- tile_n_size (Int): Size of each tile in the N dimension.
- num_m_mmas (Int): Number of MMA tiles in the M dimension.
- num_consumer (Int): Number of consumer warp groups.
- use_x2_for_last_iter (Bool): Whether to use the x2 variant of stmatrix (two 8x8 matrices per instruction) for the last iteration.
- WG_BM (Int): Warp group tile size in the M dimension.
- WG_BN (Int): Warp group tile size in the N dimension.
- sub_wg_bn_id (Int): Sub-warp-group index in the N dimension.
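The following Python sketch offers one plausible reading of use_x2_for_last_iter. It rests on two assumptions: each warp stores a 16-row bf16 slab, and one x4 stmatrix (four 8x8 matrices) covers a 16x16 block while one x2 covers 16x8. Under those assumptions, a width that is an odd multiple of 8 leaves a final 8-column strip that only needs x2. The helper name plan_stmatrix is hypothetical.

```python
# Hypothetical planning helper: how many stmatrix stores a warp needs to cover
# a 16-row by n-column bf16 slab. Assumes .x4 covers 16x16 (four 8x8 matrices)
# and .x2 covers 16x8 (two 8x8 matrices). This mapping is an assumption, not
# taken from FragmentToSMemWriter's source.
def plan_stmatrix(n_cols: int) -> list[str]:
    if n_cols % 8 != 0:
        raise ValueError("width must be a multiple of 8 (one 8x8 matrix)")
    plan = ["x4"] * (n_cols // 16)
    if n_cols % 16 == 8:
        plan.append("x2")  # the leftover 8 columns need only two 8x8 matrices
    return plan


print(plan_stmatrix(64))  # ['x4', 'x4', 'x4', 'x4']
print(plan_stmatrix(40))  # ['x4', 'x4', 'x2']
```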
Fields
- c_tile (LayoutTensor[c_type, c_tile_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=128]): Shared memory tile to write to.
- warp_group_thread_idx (UInt): Thread index within the warp group.
- local_warp_group_idx (UInt): Warp group index within the consumer groups.
- st_matrix_swizzle (Swizzle): Swizzle pattern for bank conflict avoidance.
- st_matrix_rt_layout (RuntimeLayout[st_matrix_n_layout[c_type, tile_n_size, num_m_mmas, num_consumer](), element_type=DType.int32, linear_idx_type=DType.int32]): Runtime layout for stmatrix operations.
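Conceptually, a layout such as st_matrix_rt_layout maps a coordinate to a linear index by taking the inner product of the coordinate with the layout's strides. The Python model below shows that mapping. The function name layout_index is hypothetical, and the shapes and strides used are illustrative: the actual shape of st_matrix_n_layout is not documented on this page.

```python
# Minimal model of a shape/stride layout: maps a multi-dimensional coordinate
# to a linear index. Shapes and strides here are hypothetical examples.
def layout_index(coord, shape, stride) -> int:
    if not len(coord) == len(shape) == len(stride):
        raise ValueError("coord, shape and stride must have the same rank")
    for c, s in zip(coord, shape):
        if not 0 <= c < s:
            raise IndexError(f"coordinate {coord} out of bounds for {shape}")
    return sum(c * d for c, d in zip(coord, stride))


# The same 4 x 8 shape, row-major (strides 8, 1) versus column-major (1, 4).
print(layout_index((2, 3), (4, 8), (8, 1)))  # 19
print(layout_index((2, 3), (4, 8), (1, 4)))  # 14
```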
Implemented traits
AnyType, Copyable, ImplicitlyCopyable, Movable, UnknownDestructibility
Aliases
__copyinit__is_trivial
alias __copyinit__is_trivial = True
__del__is_trivial
alias __del__is_trivial = True
__moveinit__is_trivial
alias __moveinit__is_trivial = True
Methods
__init__
__init__(c_tile: LayoutTensor[c_type, c_tile_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=128], warp_group_thread_idx: UInt, local_warp_group_idx: UInt, st_matrix_swizzle: Swizzle, st_matrix_rt_layout: RuntimeLayout[st_matrix_n_layout[c_type, tile_n_size, num_m_mmas, num_consumer](), element_type=DType.int32, linear_idx_type=DType.int32]) -> Self
Initialize the fragment writer.
Args:
- c_tile (LayoutTensor): Shared memory tile to write to.
- warp_group_thread_idx (UInt): Thread index within the warp group.
- local_warp_group_idx (UInt): Warp group index within the consumer groups.
- st_matrix_swizzle (Swizzle): Swizzle pattern for bank conflict avoidance.
- st_matrix_rt_layout (RuntimeLayout): Runtime layout for stmatrix operations.
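On Hopper, a warp group is 4 warps (128 threads). The Python sketch below shows one way a caller might derive warp_group_thread_idx and local_warp_group_idx from a flat thread index. It assumes that the producer warp groups occupy the first threads of the block and that the consumer groups follow them. That ordering, and the helper writer_indices, are hypothetical.

```python
WARP_SIZE = 32
WARP_GROUP_SIZE = 4 * WARP_SIZE  # 128 threads per warp group on Hopper


def writer_indices(thread_idx: int, num_producer_groups: int = 1) -> tuple[int, int]:
    """Hypothetical derivation of (warp_group_thread_idx, local_warp_group_idx).

    Assumes the producer warp groups come first in the thread block.
    """
    warp_group_thread_idx = thread_idx % WARP_GROUP_SIZE
    local_warp_group_idx = thread_idx // WARP_GROUP_SIZE - num_producer_groups
    if local_warp_group_idx < 0:
        raise ValueError("thread belongs to a producer warp group")
    return warp_group_thread_idx, local_warp_group_idx


print(writer_indices(130))  # (2, 0): first consumer group
print(writer_indices(300))  # (44, 1): second consumer group
```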
write_tile
write_tile(self, c_reg_tile: LayoutTensor[_dtype, layout, MutableAnyOrigin, address_space=AddressSpace(5), alignment=alignment], coords: Tuple[UInt, UInt])
Write accumulator fragments to shared memory at specified tile coordinates.
Args:
- c_reg_tile (LayoutTensor): Source register tile containing the accumulator values.
- coords (Tuple): Tile coordinates (row_tile_idx, col_tile_idx) specifying where to write.
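To picture what each lane contributes during a write, here is a Python model of the register layout that the PTX stmatrix .m8n8 .b16 instruction expects. Within each 8x8 matrix, lane t holds two adjacent 16-bit values at row t // 4, columns 2 * (t % 4) and 2 * (t % 4) + 1. For .x4, lanes 8*i through 8*i + 7 supply the shared-memory row addresses of matrix i. This model does not reproduce write_tile itself, and how the accumulator registers in c_reg_tile are grouped into these matrices is not documented here.

```python
# Python model of the stmatrix .m8n8 .b16 fragment layout described in the PTX ISA.
def fragment_positions(lane: int) -> list[tuple[int, int]]:
    """(row, col) positions inside one 8x8 matrix held by `lane` (0-31)."""
    row, col = lane // 4, 2 * (lane % 4)
    return [(row, col), (row, col + 1)]


def address_lane(matrix: int, row: int) -> int:
    """Lane that supplies the address of `row` in `matrix` (0-3 for .x4)."""
    return 8 * matrix + row


# The 32 lanes together cover every element of the 8x8 matrix exactly once.
covered = sorted(p for lane in range(32) for p in fragment_positions(lane))
print(len(covered), covered == [(r, c) for r in range(8) for c in range(8)])  # 64 True
print(fragment_positions(5))  # [(1, 2), (1, 3)]
print(address_lane(2, 3))     # 19
```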