Mojo struct
RegisterToGMemWriter
@register_passable(trivial)
struct RegisterToGMemWriter[c_type: DType, dst_layout: Layout, dst_address_space: AddressSpace, dst_element_layout: Layout, dst_layout_int_type: DType, dst_linear_idx_type: DType, dst_masked: Bool, dst_alignment: Int, //, wgmma_shape: IndexList[3], num_consumer: Int, N: Int, epilogue_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, check_n_bounds: Bool = False]
Writer for transferring accumulator registers directly to global memory.
This writer copies accumulator values from register tiles directly into global-memory tiles, distributing the work across the threads of a warp group and respecting the destination alignment. It supports optional epilogue processing, compute-lambda transformations, and bounds checking.
Note: At most one of epilogue_fn or compute_lambda_fn should be set.
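The two optional callbacks differ only in who stores the result. The following is a minimal Python sketch of per-element dispatch, not the library's implementation. The helper write_element, the list-of-lists destination, and the rule that an epilogue replaces the writer's own store are assumptions inferred from the callback signatures: epilogue_fn returns nothing, while compute_lambda_fn returns a new value.

```python
from typing import Callable, Optional

# Hypothetical Python stand-ins for the two optional callbacks. In Mojo they
# are compile-time parameters; here they are ordinary callables.
EpilogueFn = Callable[[tuple[int, int], float], None]
ComputeLambdaFn = Callable[[tuple[int, int], float], float]


def write_element(
    dst: list[list[float]],
    coord: tuple[int, int],
    value: float,
    epilogue_fn: Optional[EpilogueFn] = None,
    compute_lambda_fn: Optional[ComputeLambdaFn] = None,
) -> None:
    """Sketch of how one accumulator value might be written (hypothetical)."""
    if epilogue_fn is not None and compute_lambda_fn is not None:
        raise ValueError("set at most one of epilogue_fn / compute_lambda_fn")
    row, col = coord
    if epilogue_fn is not None:
        # Assumption: the epilogue receives the coordinate and value and handles
        # the value itself, so the writer does not store it.
        epilogue_fn(coord, value)
    elif compute_lambda_fn is not None:
        # The lambda returns a transformed value, which the writer stores.
        dst[row][col] = compute_lambda_fn(coord, value)
    else:
        # No callback: plain copy of the accumulator value.
        dst[row][col] = value
```

For example, passing compute_lambda_fn=lambda rc, v: v * 2 stores twice the accumulator value at rc, while an epilogue that only records its arguments leaves the destination untouched.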
Parameters
- c_type (DType): Output data type.
- dst_layout (Layout): Layout of the destination tensor.
- dst_address_space (AddressSpace): Address space of the destination tensor.
- dst_element_layout (Layout): Element layout of the destination tensor.
- dst_layout_int_type (DType): Integer type for destination layout indices.
- dst_linear_idx_type (DType): Linear index type for the destination tensor.
- dst_masked (Bool): Whether the destination tensor is masked.
- dst_alignment (Int): Alignment requirement for the destination tensor.
- wgmma_shape (IndexList): Shape of the WGMMA operation [M, N, K].
- num_consumer (Int): Number of consumer warp groups.
- N (Int): Matrix N dimension.
- epilogue_fn (OptionalReg): Optional epilogue function. It is called with the output coordinates and value and returns nothing.
- compute_lambda_fn (OptionalReg): Optional compute lambda. It is called with the output coordinates and value and returns a new value.
- check_n_bounds (Bool): Whether to perform bounds checking on the N dimension.
Fields
- thread_info (ThreadInfo)
- dst (LayoutTensor[c_type, dst_layout, MutableAnyOrigin, address_space=dst_address_space, element_layout=dst_element_layout, layout_int_type=dst_layout_int_type, linear_idx_type=dst_linear_idx_type, masked=dst_masked, alignment=dst_alignment]): Destination tensor in global memory.
- num_m_mmas (Int): Number of MMA tiles in the M dimension.
- tile_coords (OptionalReg[TileCoordinates]): Optional tile coordinates for epilogue processing.
- M_bound (OptionalReg[UInt]): Optional maximum valid M coordinate.
- N_bound (OptionalReg[UInt32]): Optional maximum valid N coordinate.
Implemented traits
AnyType, Copyable, ImplicitlyCopyable, Movable, UnknownDestructibility
Aliases
__copyinit__is_trivial
alias __copyinit__is_trivial = True
__del__is_trivial
alias __del__is_trivial = True
__moveinit__is_trivial
alias __moveinit__is_trivial = True
c_frag_size
alias c_frag_size = (wgmma_shape[0] * wgmma_shape[1]) // _resolve_warpgroup_size()
Number of accumulator elements held by each thread: the M x N WGMMA output divided across the threads of one warp group (four warps, 128 threads, on NVIDIA Hopper).
num_frag_mats
alias num_frag_mats = (wgmma_shape[1] // 8) * ((wgmma_shape[0] // 4) // 8)
Total number of 8 x 8 fragment matrices, equal to num_n_frag_mat * num_m_frag_mat.
num_m_frag_mat
alias num_m_frag_mat = (wgmma_shape[0] // 4) // 8
Number of 8-row fragment matrices along M within one warp's share of the M rows.
num_n_frag_mat
alias num_n_frag_mat = wgmma_shape[1] // 8
Number of 8-column fragment matrices along N.
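Evaluating the alias formulas for a concrete shape shows how the fragment counts relate. This Python sketch assumes _resolve_warpgroup_size() returns 128 (four 32-thread warps, as on NVIDIA Hopper); the function frag_aliases is hypothetical and simply mirrors the four alias expressions.

```python
# Assumption: _resolve_warpgroup_size() == 128 (4 warps x 32 threads).
WARPGROUP_SIZE = 128


def frag_aliases(wgmma_shape: tuple[int, int, int]) -> dict[str, int]:
    """Evaluate the struct's alias formulas for a given [M, N, K] shape."""
    m, n, _k = wgmma_shape
    num_m_frag_mat = (m // 4) // 8  # M split over 4 warps, then into 8-row matrices
    num_n_frag_mat = n // 8         # N split into 8-column matrices
    return {
        "c_frag_size": (m * n) // WARPGROUP_SIZE,
        "num_frag_mats": num_n_frag_mat * num_m_frag_mat,
        "num_m_frag_mat": num_m_frag_mat,
        "num_n_frag_mat": num_n_frag_mat,
    }


print(frag_aliases((64, 256, 16)))
# {'c_frag_size': 128, 'num_frag_mats': 64, 'num_m_frag_mat': 2, 'num_n_frag_mat': 32}
```

For the M = 64 shapes that WGMMA uses, c_frag_size always equals 2 * num_frag_mats: each 8 x 8 matrix (64 elements) is spread over the 32 threads of one warp, two elements per thread.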
Methods
__init__
__init__(dst: LayoutTensor[c_type, dst_layout, MutableAnyOrigin, address_space=dst_address_space, element_layout=dst_element_layout, layout_int_type=dst_layout_int_type, linear_idx_type=dst_linear_idx_type, masked=dst_masked, alignment=dst_alignment], warp_group_thread_idx: UInt, num_m_mmas: Int, tile_coords: OptionalReg[TileCoordinates] = None, M_bound: OptionalReg[UInt] = None, N_bound: OptionalReg[UInt32] = None) -> Self
Initialize the register-to-global-memory writer.
Args:
- dst (LayoutTensor): Destination tensor in global memory.
- warp_group_thread_idx (UInt): Thread index within the warp group.
- num_m_mmas (Int): Number of MMA tiles in the M dimension.
- tile_coords (OptionalReg): Optional tile coordinates for epilogue processing.
- M_bound (OptionalReg): Optional maximum valid M coordinate (for the epilogue).
- N_bound (OptionalReg): Optional maximum valid N coordinate (for bounds checking).
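When N is not a multiple of the tile width, a thread's registers can cover columns past the matrix edge, which is what N_bound and check_n_bounds guard against. The Python sketch below shows that masking on one row segment. The helper store_segment is hypothetical, and treating n_bound as an exclusive limit (the number of valid columns) is an assumption; the reference above only calls it the maximum valid N coordinate. M_bound is not modeled.

```python
from typing import Optional, Sequence


def store_segment(
    dst_row: list[float],
    col0: int,
    values: Sequence[float],
    n_bound: Optional[int] = None,
    check_n_bounds: bool = False,
) -> None:
    """Store values at columns col0, col0 + 1, ... of dst_row (hypothetical).

    Assumption: n_bound is exclusive, i.e. columns >= n_bound are invalid.
    """
    for offset, value in enumerate(values):
        col = col0 + offset
        if check_n_bounds and n_bound is not None and col >= n_bound:
            # Columns only increase, so every remaining value is out of range too.
            break
        dst_row[col] = value


row = [0.0] * 8  # one tile row padded to 8 columns
store_segment(row, 4, [1.0, 2.0, 3.0, 4.0], n_bound=6, check_n_bounds=True)
print(row)  # [0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0]
```

With check_n_bounds left at False, all four values are stored, which is only safe when the destination really has those columns.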
write_tile
write_tile(self, c_reg_tile: LayoutTensor[_dtype, layout, MutableAnyOrigin, address_space=AddressSpace(5), alignment=alignment], coords: Tuple[UInt, UInt])
Write a single MMA tile from registers to global memory. The source tile is in AddressSpace(5), the per-thread local address space used for register tiles.
Args:
- c_reg_tile (LayoutTensor): Register tile containing accumulator values.
- coords (Tuple): Tile coordinates (row, column) in the destination matrix.
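Writing a tile requires knowing which output element each of a thread's c_frag_size registers holds. The Python sketch below models this for a 64 x N accumulator, assuming the standard PTX wgmma accumulator (D) fragment layout. Each warp owns 16 rows. Within every 8-column chunk, lane L holds two adjacent columns in row L // 4 and the same two columns 8 rows lower. The function accum_coord is hypothetical, and the writer's actual traversal order may differ.

```python
def accum_coord(thread_idx: int, frag_idx: int) -> tuple[int, int]:
    """Map (thread in warp group, register index) to (row, col) in a 64 x N tile.

    Assumes the PTX wgmma D-fragment layout; frag_idx ranges over 0 .. N/2 - 1.
    """
    warp, lane = divmod(thread_idx, 32)      # 4 warps of 32 threads
    group, tid_in_group = divmod(lane, 4)    # 8 groups of 4 lanes per warp
    n_chunk, r = divmod(frag_idx, 4)         # every 4 registers cover one 8-column chunk
    row = warp * 16 + group + 8 * (r // 2)   # registers 2 and 3 sit 8 rows lower
    col = n_chunk * 8 + tid_in_group * 2 + (r % 2)
    return row, col


print(accum_coord(5, 3))   # (9, 3)
print(accum_coord(37, 4))  # (17, 10)
```

Over all 128 threads and N/2 registers per thread, the mapping covers every element of the 64 x N tile exactly once, matching c_frag_size = 64 * N / 128.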