Mojo struct

RegisterToGMemWriter

@register_passable("trivial")
struct RegisterToGMemWriter[
    c_type: DType,
    dst_layout: Layout,
    dst_address_space: AddressSpace,
    dst_element_layout: Layout,
    dst_layout_int_type: DType,
    dst_linear_idx_type: DType,
    dst_masked: Bool,
    dst_alignment: Int, //,
    wgmma_shape: IndexList[3],
    num_consumer: Int,
    N: Int,
    epilogue_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None,
    compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None,
    check_runtime_bounds: Bool = False
]

Writer for transferring accumulator registers directly to global memory.

This writer copies register tiles directly to global memory tiles, distributing the work across the threads of a warp group and honoring the destination alignment. It supports an optional epilogue function, an optional compute lambda transformation, and optional runtime bounds checking.

Note: At most one of epilogue_fn or compute_lambda_fn should be set.

Parameters

  • c_type (DType): Output data type.
  • dst_layout (Layout): Layout of the destination tensor.
  • dst_address_space (AddressSpace): Address space of the destination tensor.
  • dst_element_layout (Layout): Element layout of the destination tensor.
  • dst_layout_int_type (DType): Integer type for destination layout indices.
  • dst_linear_idx_type (DType): Linear index type for destination tensor.
  • dst_masked (Bool): Whether the destination tensor is masked.
  • dst_alignment (Int): Alignment requirement for destination tensor.
  • wgmma_shape (IndexList): Shape of the WGMMA operation [M, N, K].
  • num_consumer (Int): Number of consumer warp groups.
  • N (Int): Matrix N dimension.
  • epilogue_fn (OptionalReg): Optional epilogue function (receives the coordinates and the value; returns nothing).
  • compute_lambda_fn (OptionalReg): Optional compute lambda function (returns new value).
  • check_runtime_bounds (Bool): Whether to perform bounds checking on N dimension.
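
The two callback parameters differ only in their return type: per the signature above, epilogue_fn returns None while compute_lambda_fn returns a new value of the same type. The following Python sketch illustrates that difference on a single element. It is not the Mojo implementation: store_element and the dict-based dst are hypothetical, and the assumption that the writer skips its own store when epilogue_fn is set is a guess made for illustration.

```python
from typing import Callable, Dict, Optional, Tuple

Coords = Tuple[int, int]


def store_element(
    dst: Dict[Coords, float],
    coords: Coords,
    value: float,
    epilogue_fn: Optional[Callable[[Coords, float], None]] = None,
    compute_lambda_fn: Optional[Callable[[Coords, float], float]] = None,
) -> None:
    """Hypothetical per-element store showing the two callback shapes."""
    # The documentation asks for at most one of the two callbacks.
    if epilogue_fn is not None and compute_lambda_fn is not None:
        raise ValueError("set at most one of epilogue_fn or compute_lambda_fn")

    if epilogue_fn is not None:
        # Returns None: the callback consumes the value itself.
        # Assumption: the writer does not store anything in this case.
        epilogue_fn(coords, value)
        return

    if compute_lambda_fn is not None:
        # Returns a new value, which is then stored.
        value = compute_lambda_fn(coords, value)

    dst[coords] = value


dst: Dict[Coords, float] = {}
store_element(dst, (0, 0), 2.0)
store_element(dst, (0, 1), 2.0, compute_lambda_fn=lambda c, v: v * 3.0)
print(dst)
```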

Fields

  • thread_info (ThreadInfo):
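  • dst (RegisterToGMemWriter[wgmma_shape, num_consumer, N, epilogue_fn, compute_lambda_fn, check_runtime_bounds].DstType): Destination tensor in global memory.
  • num_m_mmas (Int): Number of MMA tiles in M dimension.
  • tile_coords (OptionalReg[TileCoordinates]): Optional tile coordinates for epilogue processing.
  • max_row (OptionalReg[UInt32]): Optional maximum valid M coordinate (for epilogue).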

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, Movable, RegTileWriter, UnknownDestructibility

Aliases

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

c_frag_size

comptime c_frag_size = ((wgmma_shape.__getitem__[3, DType.int64, Int](0) * wgmma_shape.__getitem__[3, DType.int64, Int](1)) // WARPGROUP_SIZE)

DstType

comptime DstType = LayoutTensor[c_type, dst_layout, MutAnyOrigin, address_space=dst_address_space, element_layout=dst_element_layout, layout_int_type=dst_layout_int_type, linear_idx_type=dst_linear_idx_type, masked=dst_masked, alignment=dst_alignment]

num_frag_mats

comptime num_frag_mats = (RegisterToGMemWriter[wgmma_shape, num_consumer, N, epilogue_fn, compute_lambda_fn, check_runtime_bounds].num_n_frag_mat * RegisterToGMemWriter[wgmma_shape, num_consumer, N, epilogue_fn, compute_lambda_fn, check_runtime_bounds].num_m_frag_mat)

num_m_frag_mat

comptime num_m_frag_mat = ((wgmma_shape.__getitem__[3, DType.int64, Int](0) // 4) // 8)

num_n_frag_mat

comptime num_n_frag_mat = (wgmma_shape.__getitem__[3, DType.int64, Int](1) // 8)
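
As a worked example of the alias arithmetic above, the following Python sketch evaluates the same integer expressions for a sample shape. The function name frag_counts and the sample shape (64, 256, 16) are illustrative choices, and WARPGROUP_SIZE = 128 assumes a warp group of 4 warps with 32 threads each.

```python
WARPGROUP_SIZE = 128  # assumption: 4 warps x 32 threads per warp group


def frag_counts(wgmma_shape):
    """Mirror the documented alias formulas with plain integer arithmetic."""
    m, n, _k = wgmma_shape
    # Accumulator elements held by each thread of the warp group.
    c_frag_size = (m * n) // WARPGROUP_SIZE
    # Fragment matrices along M and N, as defined by the aliases.
    num_m_frag_mat = (m // 4) // 8
    num_n_frag_mat = n // 8
    num_frag_mats = num_n_frag_mat * num_m_frag_mat
    return {
        "c_frag_size": c_frag_size,
        "num_m_frag_mat": num_m_frag_mat,
        "num_n_frag_mat": num_n_frag_mat,
        "num_frag_mats": num_frag_mats,
    }


print(frag_counts((64, 256, 16)))
```

For the shape (64, 256, 16) this gives c_frag_size = 128, num_m_frag_mat = 2, num_n_frag_mat = 32, and num_frag_mats = 64, so each thread holds two accumulator elements per fragment matrix.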

Methods

__init__

__init__(dst: LayoutTensor[c_type, dst_layout, MutAnyOrigin, address_space=dst_address_space, element_layout=dst_element_layout, layout_int_type=dst_layout_int_type, linear_idx_type=dst_linear_idx_type, masked=dst_masked, alignment=dst_alignment], warp_group_thread_idx: UInt, num_m_mmas: Int, tile_coords: OptionalReg[TileCoordinates] = None, max_row: OptionalReg[UInt32] = None) -> Self

Initialize the register-to-global-memory writer.

Args:

  • dst (LayoutTensor): Destination tensor in global memory.
  • warp_group_thread_idx (UInt): Thread index within the warp group.
  • num_m_mmas (Int): Number of MMA tiles in M dimension.
  • tile_coords (OptionalReg): Optional tile coordinates for epilogue processing.
  • max_row (OptionalReg): Optional maximum valid M coordinate (for epilogue).

write_tile

write_tile(self, c_reg_tile: LayoutTensor[_dtype, layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], coords: Tuple[UInt, UInt])

Write a single MMA tile from registers to global memory.

Args:

  • c_reg_tile (LayoutTensor): Register tile containing accumulator values.
  • coords (Tuple): Tile coordinates (row, column) in the destination matrix.
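
The bounds checking mentioned for check_runtime_bounds and max_row can be pictured with a small host-side sketch. This is only an illustration of a masked tile write, not the GPU implementation: write_tile_checked, the list-of-lists tensors, the interpretation of coords as tile indices, and the treatment of max_row as an exclusive upper bound are all assumptions.

```python
def write_tile_checked(dst, tile, coords, n, max_row=None):
    """Copy `tile` into `dst` at tile position `coords`, skipping
    elements that fall outside the valid region.

    dst     : list of rows (destination matrix)
    tile    : list of rows (the register tile, here a plain 2D list)
    coords  : (tile_row, tile_col), assumed to be tile indices
    n       : number of valid columns (the N dimension)
    max_row : optional exclusive bound on valid rows (assumption)
    """
    tile_m, tile_n = len(tile), len(tile[0])
    row0 = coords[0] * tile_m
    col0 = coords[1] * tile_n
    written = 0
    for i in range(tile_m):
        for j in range(tile_n):
            r, c = row0 + i, col0 + j
            if c >= n:
                continue  # column lies past the N dimension
            if max_row is not None and r >= max_row:
                continue  # row lies past the last valid row
            dst[r][c] = tile[i][j]
            written += 1
    return written


dst = [[0] * 3 for _ in range(4)]
count = write_tile_checked(dst, [[1, 2], [3, 4]], (1, 1), n=3, max_row=3)
print(count, dst)
```

In this example the 2x2 tile at tile position (1, 1) starts at element (2, 2) of a 4x3 destination, so only its top-left element lies inside both bounds and is written.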
