
Mojo struct

RegisterToGMemWriter

struct RegisterToGMemWriter[c_type: DType, dst_layout: Layout, dst_address_space: AddressSpace, dst_element_layout: Layout, dst_layout_int_type: DType, dst_linear_idx_type: DType, dst_masked: Bool, dst_alignment: Int, //, wgmma_shape: IndexList[3], num_consumer: Int, N: Int, epilogue_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, compute_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, check_runtime_bounds: Bool = False, swapAB: Bool = False]

Writer for transferring accumulator registers directly to global memory.

This writer copies accumulator values from register tiles directly into global memory tiles, distributing the elements across the threads of a warpgroup and respecting the destination's alignment. It optionally supports an epilogue function, a compute-lambda transformation, and runtime bounds checking.
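The two optional callbacks in the signature imply a per-element write path. The Python model below is a hypothetical simplification, not the Mojo implementation. It treats values as scalars rather than SIMD vectors. It assumes that `epilogue_fn` takes over the store when present, since its return type is `None`, and that the result of `compute_lambda_fn` is what gets stored. It also assumes that bounds checking drops rows at or beyond `max_row`. The helper name `write_element` is invented for illustration.

```python
from typing import Callable, Optional

Coord = tuple[int, int]


def write_element(
    dst: list[list[float]],
    coord: Coord,
    value: float,
    epilogue_fn: Optional[Callable[[Coord, float], None]] = None,
    compute_lambda_fn: Optional[Callable[[Coord, float], float]] = None,
    check_runtime_bounds: bool = False,
    max_row: Optional[int] = None,
) -> None:
    """Hypothetical scalar model of storing one accumulator value."""
    row, col = coord
    # Assumed bounds check: rows at or past max_row are skipped entirely.
    if check_runtime_bounds and max_row is not None and row >= max_row:
        return
    if epilogue_fn is not None:
        # Assumed: the epilogue returns nothing, so it performs the store itself.
        epilogue_fn(coord, value)
    elif compute_lambda_fn is not None:
        # Assumed: the lambda transforms the value and the writer stores the result.
        dst[row][col] = compute_lambda_fn(coord, value)
    else:
        # Plain register-to-memory copy.
        dst[row][col] = value


dst = [[0.0] * 4 for _ in range(2)]
write_element(dst, (0, 1), 3.0, compute_lambda_fn=lambda c, v: v * 2)
write_element(dst, (1, 2), 5.0, check_runtime_bounds=True, max_row=1)  # dropped
```

The precedence of `epilogue_fn` over `compute_lambda_fn` is only one plausible choice. Check the Mojo source before relying on how the two interact when both are set.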

Parameters

Fields

  • thread_info (ThreadInfo):
  • dst (Self.DstType):
  • num_m_mmas (Int):
  • tile_coords (OptionalReg[TileCoordinates]):
  • max_row (OptionalReg[UInt32]):

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegTileWriter, RegisterPassable, TrivialRegisterPassable

comptime members

c_frag_size

comptime c_frag_size = ((wgmma_shape[0] * wgmma_shape[1]) // WARPGROUP_SIZE)
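`c_frag_size` counts the accumulator values held by each thread of a warpgroup. The Python sketch below enumerates which (row, col) entries of a 64 x N wgmma accumulator tile a single thread owns. It follows the D-fragment layout that the PTX ISA documents for wgmma with a 32-bit accumulator and assumes `WARPGROUP_SIZE` is 128 (4 warps of 32 threads). The function name `owned_coords` is hypothetical, and the element order is illustrative rather than the register order Mojo uses.

```python
WARPGROUP_SIZE = 128  # assumption: 4 warps x 32 threads per warpgroup


def owned_coords(thread_idx: int, n: int) -> list[tuple[int, int]]:
    """(row, col) entries of a 64 x n accumulator held by one warpgroup thread."""
    warp, lane = divmod(thread_idx, 32)
    row0 = warp * 16 + lane // 4  # each warp owns a 16-row band
    col0 = (lane % 4) * 2         # two adjacent columns per 8-column block
    coords = []
    for n_blk in range(n // 8):   # walk the 8-column blocks
        for dr in (0, 8):         # upper and lower 8-row halves of the band
            for dc in (0, 1):     # the two adjacent columns
                coords.append((row0 + dr, n_blk * 8 + col0 + dc))
    return coords


# For a 64 x 128 tile, each thread holds 64 * 128 // 128 = 64 values.
print(len(owned_coords(0, 128)))  # 64
```

Across all 128 threads the sets are disjoint and together cover the whole 64 x N tile. This is why the per-thread count equals `(wgmma_shape[0] * wgmma_shape[1]) // WARPGROUP_SIZE`.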

DstType

comptime DstType = LayoutTensor[c_type, dst_layout, MutAnyOrigin, address_space=dst_address_space, element_layout=dst_element_layout, layout_int_type=dst_layout_int_type, linear_idx_type=dst_linear_idx_type, masked=dst_masked, alignment=dst_alignment]

num_frag_mats

comptime num_frag_mats = (Self.num_n_frag_mat * Self.num_m_frag_mat)

num_m_frag_mat

comptime num_m_frag_mat = ((wgmma_shape[0] // 4) // 8)

num_n_frag_mat

comptime num_n_frag_mat = (wgmma_shape[1] // 8)
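As a worked check, the Python sketch below recomputes the four comptime members for a single wgmma shape. It assumes `WARPGROUP_SIZE` is 128. It also reads `wgmma_shape[0] // 4` as the rows handled by each of the 4 warps, which is an interpretation rather than something stated on this page. The function name `writer_constants` is hypothetical.

```python
WARPGROUP_SIZE = 128  # assumption: 4 warps x 32 threads per warpgroup


def writer_constants(wgmma_shape: tuple[int, int, int]) -> dict[str, int]:
    """Recompute RegisterToGMemWriter's comptime members for one wgmma shape."""
    m, n, _k = wgmma_shape
    c_frag_size = (m * n) // WARPGROUP_SIZE  # accumulator values per thread
    num_m_frag_mat = (m // 4) // 8           # rows per warp, in 8-row fragments
    num_n_frag_mat = n // 8                  # 8-column fragments
    num_frag_mats = num_n_frag_mat * num_m_frag_mat
    # Each 8x8 fragment spreads 64 values over 32 lanes (2 per thread),
    # so the fragments must account for every value a thread holds.
    assert num_frag_mats * (8 * 8 // 32) == c_frag_size
    return {
        "c_frag_size": c_frag_size,
        "num_m_frag_mat": num_m_frag_mat,
        "num_n_frag_mat": num_n_frag_mat,
        "num_frag_mats": num_frag_mats,
    }


print(writer_constants((64, 128, 16)))
# {'c_frag_size': 64, 'num_m_frag_mat': 2, 'num_n_frag_mat': 16, 'num_frag_mats': 32}
```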

Methods

__init__

__init__(dst: LayoutTensor[c_type, dst_layout, MutAnyOrigin, address_space=dst_address_space, element_layout=dst_element_layout, layout_int_type=dst_layout_int_type, linear_idx_type=dst_linear_idx_type, masked=dst_masked, alignment=dst_alignment], warp_group_thread_idx: Int, num_m_mmas: Int, tile_coords: OptionalReg[TileCoordinates] = None, max_row: OptionalReg[UInt32] = None) -> Self

Initialize the register-to-global-memory writer.

Args:

write_tile

write_tile(self, c_reg_tile: LayoutTensor[MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=c_reg_tile.element_layout, layout_int_type=c_reg_tile.layout_int_type, linear_idx_type=c_reg_tile.linear_idx_type, masked=c_reg_tile.masked, alignment=c_reg_tile.alignment], coords: Tuple[Int, Int])

Write a single MMA tile from registers to global memory.

Args: