
Mojo struct

RegisterToGMemWriter

@register_passable(trivial) struct RegisterToGMemWriter[c_type: DType, dst_layout: Layout, dst_address_space: AddressSpace, dst_element_layout: Layout, dst_layout_int_type: DType, dst_linear_idx_type: DType, dst_masked: Bool, dst_alignment: Int, //, wgmma_shape: IndexList[3], num_consumer: Int, N: Int, epilogue_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, check_n_bounds: Bool = False]

Writer for transferring accumulator registers directly to global memory.

This writer copies register tiles straight to global memory tiles, distributing the work across the threads of the warp group and honoring the destination alignment. It optionally supports epilogue processing, compute-lambda transformations, and bounds checking on the N dimension.

Note: At most one of epilogue_fn or compute_lambda_fn should be set.
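The note above implies three ways an accumulator value can leave the writer. As a rough model only, the dispatch might look like the following Python sketch, with scalars standing in for SIMD vectors. write_value is a hypothetical helper, not part of the Mojo API. The assumption that an epilogue_fn handles the value itself (so nothing is stored to dst on that path) is inferred from its None return type, not from the implementation.

```python
from typing import Callable, Dict, Optional, Tuple

Coord = Tuple[int, int]  # (row, column) in the output matrix


def write_value(
    dst: Dict[Coord, float],
    coord: Coord,
    value: float,
    epilogue_fn: Optional[Callable[[Coord, float], None]] = None,
    compute_lambda_fn: Optional[Callable[[Coord, float], float]] = None,
) -> None:
    """Hypothetical model of how one accumulator value leaves the writer."""
    if epilogue_fn is not None and compute_lambda_fn is not None:
        raise ValueError("set at most one of epilogue_fn or compute_lambda_fn")
    if epilogue_fn is not None:
        # Assumed: the epilogue receives the value and handles it itself;
        # nothing is stored into dst on this path.
        epilogue_fn(coord, value)
    elif compute_lambda_fn is not None:
        # The lambda returns a transformed value, which is then stored.
        dst[coord] = compute_lambda_fn(coord, value)
    else:
        # No hook: store the accumulator value as-is.
        dst[coord] = value


# Usage: scale one value by 3 before it is stored.
out: Dict[Coord, float] = {}
write_value(out, (0, 2), 2.0, compute_lambda_fn=lambda c, v: v * 3)
```

Raising on a conflicting configuration mirrors the "at most one" rule; in the Mojo struct both hooks are compile-time parameters, so the choice is presumably resolved at compile time rather than checked per value.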

Parameters

  • c_type (DType): Output data type.
  • dst_layout (Layout): Layout of the destination tensor.
  • dst_address_space (AddressSpace): Address space of the destination tensor.
  • dst_element_layout (Layout): Element layout of the destination tensor.
  • dst_layout_int_type (DType): Integer type for destination layout indices.
  • dst_linear_idx_type (DType): Linear index type for destination tensor.
  • dst_masked (Bool): Whether the destination tensor is masked.
  • dst_alignment (Int): Alignment requirement for destination tensor.
  • wgmma_shape (IndexList): Shape of the WGMMA operation [M, N, K].
  • num_consumer (Int): Number of consumer warp groups.
  • N (Int): Size of the N dimension of the output matrix.
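  • epilogue_fn (OptionalReg): Optional epilogue function, called with an IndexList[2] coordinate and a SIMD vector of output values; it returns None and acts through side effects.
  • compute_lambda_fn (OptionalReg): Optional compute lambda, called with the same arguments; it returns a transformed SIMD vector, which is written to the destination.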
  • check_n_bounds (Bool): Whether to perform bounds checking on N dimension.

Fields

  • thread_info (ThreadInfo):
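  • dst (LayoutTensor[c_type, dst_layout, MutableAnyOrigin, address_space=dst_address_space, element_layout=dst_element_layout, layout_int_type=dst_layout_int_type, linear_idx_type=dst_linear_idx_type, masked=dst_masked, alignment=dst_alignment]): Destination tensor in global memory.
  • num_m_mmas (Int): Number of MMA tiles in the M dimension.
  • tile_coords (OptionalReg[TileCoordinates]): Optional tile coordinates used for epilogue processing.
  • M_bound (OptionalReg[UInt]): Optional maximum valid M coordinate, used by the epilogue.
  • N_bound (OptionalReg[UInt32]): Optional maximum valid N coordinate, used for bounds checking.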

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, Movable, UnknownDestructibility

Aliases

__copyinit__is_trivial

alias __copyinit__is_trivial = True

__del__is_trivial

alias __del__is_trivial = True

__moveinit__is_trivial

alias __moveinit__is_trivial = True

c_frag_size

alias c_frag_size = (wgmma_shape[0] * wgmma_shape[1]) // _resolve_warpgroup_size()

num_frag_mats

alias num_frag_mats = (wgmma_shape[1] // 8) * ((wgmma_shape[0] // 4) // 8)

num_m_frag_mat

alias num_m_frag_mat = (wgmma_shape[0] // 4) // 8

num_n_frag_mat

alias num_n_frag_mat = wgmma_shape[1] // 8
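The four aliases are plain integer arithmetic on wgmma_shape = [M, N, K]. The Python sketch below evaluates them for one example shape. It assumes _resolve_warpgroup_size() returns 128 (four 32-thread warps, as on NVIDIA Hopper); frag_aliases is a hypothetical helper name.

```python
WARPGROUP_SIZE = 128  # assumed value of _resolve_warpgroup_size(): 4 warps x 32 threads


def frag_aliases(wgmma_shape):
    """Evaluate the struct's fragment aliases for a given [M, N, K] shape."""
    m, n, _k = wgmma_shape
    # Accumulator elements held by each thread for one WGMMA tile.
    c_frag_size = (m * n) // WARPGROUP_SIZE
    # Each of the 4 warps owns M/4 rows, split into 8-row fragment matrices.
    num_m_frag_mat = (m // 4) // 8
    # Columns are split into 8-column fragment matrices.
    num_n_frag_mat = n // 8
    num_frag_mats = num_n_frag_mat * num_m_frag_mat
    return {
        "c_frag_size": c_frag_size,
        "num_frag_mats": num_frag_mats,
        "num_m_frag_mat": num_m_frag_mat,
        "num_n_frag_mat": num_n_frag_mat,
    }


aliases = frag_aliases((64, 128, 16))
```

For a 64 × 128 × 16 shape this gives c_frag_size = 64, num_m_frag_mat = 2, num_n_frag_mat = 16 and num_frag_mats = 32. The numbers are consistent: each warp covers 32 fragment matrices of 8 × 8 elements, and spreading 64 elements over 32 lanes leaves two values per thread per fragment matrix, so 32 × 2 = 64 values per thread.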

Methods

__init__

__init__(dst: LayoutTensor[c_type, dst_layout, MutableAnyOrigin, address_space=dst_address_space, element_layout=dst_element_layout, layout_int_type=dst_layout_int_type, linear_idx_type=dst_linear_idx_type, masked=dst_masked, alignment=dst_alignment], warp_group_thread_idx: UInt, num_m_mmas: Int, tile_coords: OptionalReg[TileCoordinates] = None, M_bound: OptionalReg[UInt] = None, N_bound: OptionalReg[UInt32] = None) -> Self

Initialize the register-to-global-memory writer.

Args:

  • dst (LayoutTensor): Destination tensor in global memory.
  • warp_group_thread_idx (UInt): Thread index within the warp group.
  • num_m_mmas (Int): Number of MMA tiles in M dimension.
  • tile_coords (OptionalReg): Optional tile coordinates for epilogue processing.
  • M_bound (OptionalReg): Optional maximum valid M coordinate (for epilogue).
  • N_bound (OptionalReg): Optional maximum valid N coordinate (for bounds checking).

write_tile

write_tile(self, c_reg_tile: LayoutTensor[_dtype, layout, MutableAnyOrigin, address_space=AddressSpace(5), alignment=alignment], coords: Tuple[UInt, UInt])

Write a single MMA tile from registers to global memory.

Args:

  • c_reg_tile (LayoutTensor): Register tile containing accumulator values.
  • coords (Tuple): Tile coordinates (row, column) in the destination matrix.
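To see which output positions a thread's register values correspond to, the Python sketch below reproduces the standard PTX wgmma accumulator layout for a 64 × N tile with 32-bit accumulators. Each warp owns 16 rows, and within a warp every thread holds two rows × two adjacent columns per 8-column block. It is an assumption that write_tile walks registers in exactly this order. accumulator_coords is a hypothetical helper, and the offset contributed by the coords argument is omitted.

```python
WARP_SIZE = 32
WARPGROUP_SIZE = 128  # assumed: 4 warps per warp group


def accumulator_coords(thread_idx: int, n: int):
    """(row, col) inside one 64 x n tile for each accumulator register of a thread."""
    warp = thread_idx // WARP_SIZE  # which warp: owns rows 16*warp .. 16*warp + 15
    lane = thread_idx % WARP_SIZE
    group = lane // 4  # row within an 8-row fragment matrix
    tid_in_group = lane % 4  # selects a pair of adjacent columns
    coords = []
    for j in range(n // 8):  # one 8-column fragment block at a time
        for k in range(4):  # 4 registers per block: 2 rows x 2 columns
            row = 16 * warp + group + 8 * (k // 2)
            col = 8 * j + 2 * tid_in_group + (k % 2)
            coords.append((row, col))
    return coords


# Usage: the first registers of thread 0 in a 64 x 128 tile.
first = accumulator_coords(0, 128)[:4]
```

Each thread ends up with n / 2 values (64 for n = 128, matching c_frag_size), and across all 128 threads every element of the 64 × n tile is covered exactly once.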
