
Mojo struct

FragmentToSMemWriter

@register_passable(trivial) struct FragmentToSMemWriter[c_type: DType, c_tile_layout: Layout, //, tile_n_size: Int, num_m_mmas: Int, num_consumer: Int, use_x2_for_last_iter: Bool, WG_BM: Int, WG_BN: Int, sub_wg_bn_id: Int]

Writer for storing accumulator fragments from registers to shared memory.

Uses st.matrix instructions to store bf16 values efficiently, and applies a swizzle pattern to the shared memory addresses to avoid bank conflicts.
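The effect of the swizzle can be illustrated with a small host-side model. The Python sketch below is not the Mojo API. It assumes the CuTe-style XOR convention (bits, base, shift) for the swizzle, a row-major bf16 tile 64 elements (128 bytes) wide, and 32 four-byte shared memory banks. It checks which banks the eight 16-byte rows of one 8x8 st.matrix tile land in, with and without the swizzle.

```python
def swizzle(offset: int, bits: int, base: int, shift: int) -> int:
    """CuTe-style XOR swizzle (assumed convention): XOR bits
    [base+shift, base+shift+bits) down onto bits [base, base+bits)."""
    mask = ((1 << bits) - 1) << (base + shift)
    return offset ^ ((offset & mask) >> shift)

BF16_BYTES = 2
ROW_ELEMS = 64   # assumed tile width: 64 bf16 values = 128 bytes per row
NUM_BANKS = 32   # 4-byte shared memory banks

def banks_touched(offset: int) -> set:
    """Banks covered by one 16-byte st.matrix row (8 bf16 values) starting at `offset`."""
    start = offset * BF16_BYTES
    return {((start + b) // 4) % NUM_BANKS for b in range(0, 16, 4)}

def row_banks(use_swizzle: bool) -> list:
    """Banks hit by the 8 rows of one 8x8 matrix stored at column 0 of the tile."""
    result = []
    for row in range(8):
        offset = row * ROW_ELEMS
        if use_swizzle:
            offset = swizzle(offset, bits=3, base=3, shift=3)
        result.append(banks_touched(offset))
    return result

plain = row_banks(False)
swizzled = row_banks(True)
print(len(set().union(*plain)))     # → 4: every row maps to banks 0..3
print(len(set().union(*swizzled)))  # → 32: the rows spread over all banks
```

Without the swizzle, all eight rows fall into the same four banks (an 8-way conflict); with it, each row is shifted into its own group of four banks.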

Parameters

  • c_type (DType): Output data type (e.g., bfloat16).
  • c_tile_layout (Layout): Shared memory tile layout.
  • tile_n_size (Int): Size of each tile in N dimension.
  • num_m_mmas (Int): Number of MMA tiles in M dimension.
  • num_consumer (Int): Number of consumer warp groups.
  • use_x2_for_last_iter (Bool): Whether to use x2 mode for the last iteration.
  • WG_BM (Int): Warp group tile size in the M dimension.
  • WG_BN (Int): Warp group tile size in the N dimension.
  • sub_wg_bn_id (Int): Sub-warp-group index in the N dimension.
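use_x2_for_last_iter selects between two widths of the st.matrix instruction. For reference, the PTX stmatrix instruction stores 8x8 matrices of 16-bit elements, and its .x1, .x2 and .x4 variants store one, two or four such matrices per warp. The Python sketch below only tabulates that arithmetic; how this writer actually splits a tile between x4 and x2 iterations is not described on this page and is not modeled.

```python
def stmatrix_footprint(num_matrices: int) -> dict:
    """Arithmetic for one warp-wide PTX stmatrix.sync.aligned.m8n8.xN.b16."""
    if num_matrices not in (1, 2, 4):
        raise ValueError("stmatrix supports .x1, .x2 and .x4")
    elements = num_matrices * 8 * 8          # each matrix is 8 x 8 b16 values
    return {
        "address_lanes": 8 * num_matrices,   # lanes 0..8N-1 each supply one row address
        "regs_per_thread": num_matrices,     # 2 b16 values per matrix per lane = one 32-bit reg
        "elements_per_warp": elements,
        "bytes_per_warp": elements * 2,      # 2 bytes per bf16 value
    }

for n in (4, 2):
    print(f"x{n}:", stmatrix_footprint(n))
```

An x2 store writes half as much data as an x4 store, which suggests why it can be useful when only a narrower chunk remains in the final iteration.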

Fields

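  • c_tile (LayoutTensor[c_type, c_tile_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=128]): Shared memory tile to write to.
  • warp_group_thread_idx (UInt): Thread index within the warp group.
  • local_warp_group_idx (UInt): Warp group index within the consumer groups.
  • st_matrix_swizzle (Swizzle): Swizzle pattern for bank conflict avoidance.
  • st_matrix_rt_layout (RuntimeLayout[st_matrix_n_layout[c_type, tile_n_size, num_m_mmas, num_consumer](), element_type=DType.int32, linear_idx_type=DType.int32]): Runtime layout for st.matrix operations.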
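st_matrix_rt_layout is a RuntimeLayout: a shape and stride pair, evaluated at run time, that maps a linear index to an offset. The exact shape produced by st_matrix_n_layout is not given on this page, so the Python sketch below only models the general mapping for a flat (non-nested) shape, assuming colexicographic (leftmost mode fastest) index decomposition. The example shape and stride are hypothetical.

```python
def layout_offset(idx: int, shape: tuple, stride: tuple) -> int:
    """Map a linear index to an offset: split idx colexicographically, then dot with stride."""
    offset = 0
    for extent, step in zip(shape, stride):
        offset += (idx % extent) * step   # coordinate along this mode times its stride
        idx //= extent                    # move on to the next (slower) mode
    return offset

# Hypothetical 4 x 8 row-major layout: shape (4, 8), stride (8, 1).
print([layout_offset(i, (4, 8), (8, 1)) for i in range(6)])  # → [0, 8, 16, 24, 1, 9]
```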

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, Movable, UnknownDestructibility

Aliases

__copyinit__is_trivial

alias __copyinit__is_trivial = True

__del__is_trivial

alias __del__is_trivial = True

__moveinit__is_trivial

alias __moveinit__is_trivial = True

Methods

__init__

__init__(c_tile: LayoutTensor[c_type, c_tile_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=128], warp_group_thread_idx: UInt, local_warp_group_idx: UInt, st_matrix_swizzle: Swizzle, st_matrix_rt_layout: RuntimeLayout[st_matrix_n_layout[c_type, tile_n_size, num_m_mmas, num_consumer](), element_type=DType.int32, linear_idx_type=DType.int32]) -> Self

Initialize the fragment writer.

Args:

  • c_tile (LayoutTensor): Shared memory tile to write to.
  • warp_group_thread_idx (UInt): Thread index within the warp group.
  • local_warp_group_idx (UInt): Warp group index within the consumer groups.
  • st_matrix_swizzle (Swizzle): Swizzle pattern for bank conflict avoidance.
  • st_matrix_rt_layout (RuntimeLayout): Runtime layout for st.matrix operations.
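The two index arguments are typically derived from the global thread index. The Python sketch below assumes 128-thread warp groups (four warps of 32 threads, as on NVIDIA Hopper) and a warp-specialized kernel in which warp group 0 is the producer and the following groups are consumers. The function name writer_indices and the producer convention are hypothetical and are not taken from this page.

```python
WARP_GROUP_SIZE = 128  # assumed: 4 warps x 32 threads

def writer_indices(thread_idx: int, num_producer_groups: int = 1) -> tuple:
    """Hypothetical derivation of (warp_group_thread_idx, local_warp_group_idx)."""
    warp_group_thread_idx = thread_idx % WARP_GROUP_SIZE
    warp_group_idx = thread_idx // WARP_GROUP_SIZE
    # Consumers are numbered from 0, skipping the assumed producer warp group(s);
    # producer threads therefore get a negative value and would not use the writer.
    local_warp_group_idx = warp_group_idx - num_producer_groups
    return warp_group_thread_idx, local_warp_group_idx

print(writer_indices(130))  # → (2, 0): thread 2 of the first consumer group
print(writer_indices(300))  # → (44, 1)
```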

write_tile

write_tile(self, c_reg_tile: LayoutTensor[_dtype, layout, MutableAnyOrigin, address_space=AddressSpace(5), alignment=alignment], coords: Tuple[UInt, UInt])

Write accumulator fragments to shared memory at the specified tile coordinates.

Args:

  • c_reg_tile (LayoutTensor): Source register tile containing accumulator values.
  • coords (Tuple): Tile coordinates (row_tile_idx, col_tile_idx) at which to write.
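As a hypothetical model of the addressing behind write_tile, the Python sketch below computes, for one st.matrix.x4 store, the swizzled element offset of the row each lane addresses. It assumes (none of this is specified on this page) that the four 8x8 matrices form a 16x16 block placed at (row_tile_idx * 16, col_tile_idx * 16), that c_tile is row-major with 64 bf16 elements per row, and a CuTe-style XOR swizzle with bits=3, base=3, shift=3.

```python
def swizzle(offset: int, bits: int = 3, base: int = 3, shift: int = 3) -> int:
    """CuTe-style XOR swizzle (assumed convention)."""
    mask = ((1 << bits) - 1) << (base + shift)
    return offset ^ ((offset & mask) >> shift)

def lane_row_offsets(coords, row_elems: int = 64, block: int = 16) -> list:
    """Swizzled start offset (in elements) of the row each lane addresses for st.matrix.x4."""
    row_tile_idx, col_tile_idx = coords
    base_row, base_col = row_tile_idx * block, col_tile_idx * block
    offsets = []
    for lane in range(32):
        matrix, row_in_matrix = divmod(lane, 8)   # lanes 8m..8m+7 address matrix m
        row = base_row + (matrix % 2) * 8 + row_in_matrix
        col = base_col + (matrix // 2) * 8
        offsets.append(swizzle(row * row_elems + col))
    return offsets

offs = lane_row_offsets((1, 2))
# Starting bank of each row in the first matrix: all eight are distinct groups of four.
print(sorted((o * 2 // 4) % 32 for o in offs[:8]))  # → [0, 4, 8, 12, 16, 20, 24, 28]
```

Because the swizzle only permutes bits 3 to 5 of the offset, every row stays 16-byte aligned and the 32 lane addresses remain distinct.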
