Mojo struct
RegisterToGMemWriter
struct RegisterToGMemWriter[c_type: DType, dst_layout: Layout, dst_address_space: AddressSpace, dst_element_layout: Layout, dst_layout_int_type: DType, dst_linear_idx_type: DType, dst_masked: Bool, dst_alignment: Int, //, wgmma_shape: IndexList[3], num_consumer: Int, N: Int, epilogue_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, compute_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, check_runtime_bounds: Bool = False, swapAB: Bool = False]
Writer for transferring accumulator registers directly to global memory.
This writer handles the direct copy from register tiles to global memory tiles, with proper thread distribution and alignment. It supports optional epilogue processing, compute lambda transformations, and bounds checking.
Parameters
- c_type (DType): Output data type.
- dst_layout (Layout): Layout of the destination tensor.
- dst_address_space (AddressSpace): Address space of the destination tensor.
- dst_element_layout (Layout): Element layout of the destination tensor.
- dst_layout_int_type (DType): Integer type for destination layout indices.
- dst_linear_idx_type (DType): Linear index type for the destination tensor.
- dst_masked (Bool): Whether the destination tensor is masked.
- dst_alignment (Int): Alignment requirement for the destination tensor.
- wgmma_shape (IndexList[3]): Shape of the WGMMA operation [M, N, K].
- num_consumer (Int): Number of consumer warp groups.
- N (Int): N dimension of the output matrix.
- epilogue_fn (Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None]): Optional epilogue function. It receives each output element's coordinates and value and returns nothing.
- compute_lambda_fn (Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]]): Optional compute lambda function. It receives each output element's coordinates and value and returns a new value.
- check_runtime_bounds (Bool): Whether to perform runtime bounds checking on the N dimension.
- swapAB (Bool): Whether to swap the A and B matrices.
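The following minimal Python model shows how a single element might be handled. It illustrates the difference between a compute lambda, which returns a new value, and an epilogue, which returns nothing. This is a sketch under stated assumptions, not the Mojo implementation:

- `write_element` is a hypothetical helper.
- The destination tensor is modeled as a nested list.
- The precedence (epilogue first, then compute lambda, then a plain store) is an assumption. The parameter descriptions do not say what happens when both functions are set.

```python
from typing import Callable, List, Optional, Tuple

Coord = Tuple[int, int]
# Python stand-ins for the two optional Mojo closures; both receive the
# element's (row, col) coordinates and its value.
EpilogueFn = Callable[[Coord, float], None]        # returns nothing
ComputeLambdaFn = Callable[[Coord, float], float]  # returns a new value


def write_element(
    out: List[List[float]],  # models the destination tensor
    coord: Coord,
    value: float,
    n: int,
    epilogue_fn: Optional[EpilogueFn] = None,
    compute_lambda_fn: Optional[ComputeLambdaFn] = None,
    check_runtime_bounds: bool = False,
) -> None:
    """Hypothetical model of writing one accumulator element."""
    row, col = coord
    # check_runtime_bounds: skip columns that fall at or beyond N.
    if check_runtime_bounds and col >= n:
        return
    if epilogue_fn is not None:
        # Assumption: the epilogue consumes the value and handles the write itself.
        epilogue_fn(coord, value)
    elif compute_lambda_fn is not None:
        # The compute lambda returns a new value, which is then stored.
        out[row][col] = compute_lambda_fn(coord, value)
    else:
        # Plain copy from the register value to the destination.
        out[row][col] = value
```

In this model, the bounds check only guards the N dimension, matching the description of `check_runtime_bounds`.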
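Fields
- thread_info (ThreadInfo): Thread indexing information for the calling thread within the warp group.
- dst (RegisterToGMemWriter[wgmma_shape, num_consumer, N, epilogue_fn, compute_lambda_fn, check_runtime_bounds, swapAB].DstType): Destination tensor in global memory.
- num_m_mmas (Int): Number of MMA tiles in the M dimension.
- tile_coords (OptionalReg[TileCoordinates]): Optional tile coordinates for epilogue processing.
- max_row (OptionalReg[UInt32]): Optional maximum valid M coordinate, used by the epilogue.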
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegTileWriter,
RegisterPassable,
TrivialRegisterPassable
comptime members
c_frag_size
comptime c_frag_size = ((wgmma_shape[0] * wgmma_shape[1]) // WARPGROUP_SIZE)
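As a worked example of c_frag_size, the Python sketch below evaluates the same expression for a few shapes. It assumes WARPGROUP_SIZE is 128, that is, four 32-thread warps as on NVIDIA Hopper. The shapes are illustrative and not taken from this page.

```python
# Assumption: a warp group is four 32-thread warps.
WARPGROUP_SIZE = 128


def c_frag_size(wgmma_shape):
    """Accumulator elements held by each thread for an [M, N, K] WGMMA."""
    m, n, _k = wgmma_shape
    # The M x N output tile is split evenly across the warp group's threads.
    return (m * n) // WARPGROUP_SIZE


print(c_frag_size((64, 128, 16)))  # 64
```

Widening N widens each thread's fragment proportionally. For example, a 64 x 256 tile gives every thread 128 accumulator values.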
DstType
comptime DstType = LayoutTensor[c_type, dst_layout, MutAnyOrigin, address_space=dst_address_space, element_layout=dst_element_layout, layout_int_type=dst_layout_int_type, linear_idx_type=dst_linear_idx_type, masked=dst_masked, alignment=dst_alignment]
num_frag_mats
comptime num_frag_mats = (RegisterToGMemWriter[wgmma_shape, num_consumer, N, epilogue_fn, compute_lambda_fn, check_runtime_bounds, swapAB].num_n_frag_mat * RegisterToGMemWriter[wgmma_shape, num_consumer, N, epilogue_fn, compute_lambda_fn, check_runtime_bounds, swapAB].num_m_frag_mat)
num_m_frag_mat
comptime num_m_frag_mat = ((wgmma_shape[0] // 4) // 8)
num_n_frag_mat
comptime num_n_frag_mat = (wgmma_shape[1] // 8)
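The three fragment counts can be checked by hand. The Python sketch below mirrors the comptime expressions above. It reads wgmma_shape[0] // 4 as the rows owned by each of the four warps in a warp group, with each warp's rows and the N extent split into 8 x 8 fragments. That reading is an interpretation, not something this page states.

```python
WARP_SIZE = 32
WARPGROUP_SIZE = 128  # assumption: four 32-thread warps per warp group


def frag_mat_counts(wgmma_shape):
    """Return (num_m_frag_mat, num_n_frag_mat, num_frag_mats) for [M, N, K]."""
    m, n, _k = wgmma_shape
    # Each of the 4 warps covers m // 4 rows, grouped into 8-row fragments.
    num_m_frag_mat = (m // 4) // 8
    # The N extent is grouped into 8-column fragments.
    num_n_frag_mat = n // 8
    return num_m_frag_mat, num_n_frag_mat, num_n_frag_mat * num_m_frag_mat


m_frags, n_frags, total = frag_mat_counts((64, 128, 16))
# An 8x8 fragment spread over 32 lanes gives each lane 8 * 8 // 32 = 2 values,
# so the fragments account for the whole per-thread accumulator (c_frag_size).
values_per_lane = 8 * 8 // WARP_SIZE
assert total * values_per_lane == (64 * 128) // WARPGROUP_SIZE
print(m_frags, n_frags, total)  # 2 16 32
```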
Methods
__init__
__init__(dst: LayoutTensor[c_type, dst_layout, MutAnyOrigin, address_space=dst_address_space, element_layout=dst_element_layout, layout_int_type=dst_layout_int_type, linear_idx_type=dst_linear_idx_type, masked=dst_masked, alignment=dst_alignment], warp_group_thread_idx: Int, num_m_mmas: Int, tile_coords: OptionalReg[TileCoordinates] = None, max_row: OptionalReg[UInt32] = None) -> Self
Initialize the register-to-global-memory writer.
Args:
- dst (LayoutTensor[c_type, dst_layout, MutAnyOrigin, address_space=dst_address_space, element_layout=dst_element_layout, layout_int_type=dst_layout_int_type, linear_idx_type=dst_linear_idx_type, masked=dst_masked, alignment=dst_alignment]): Destination tensor in global memory.
- warp_group_thread_idx (Int): Thread index within the warp group.
- num_m_mmas (Int): Number of MMA tiles in the M dimension.
- tile_coords (OptionalReg[TileCoordinates]): Optional tile coordinates for epilogue processing.
- max_row (OptionalReg[UInt32]): Optional maximum valid M coordinate, used by the epilogue.
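warp_group_thread_idx is a position within a warp group, assumed here to contain 128 threads. A common decomposition splits it into a warp index and a lane index, as sketched below in Python. The helper name is hypothetical. Whether ThreadInfo stores exactly these values is an assumption.

```python
WARP_SIZE = 32
WARPGROUP_SIZE = 128  # assumption: four 32-thread warps


def split_warp_group_thread_idx(warp_group_thread_idx: int):
    """Hypothetical split of a warp-group thread index into (warp, lane)."""
    if not 0 <= warp_group_thread_idx < WARPGROUP_SIZE:
        raise ValueError("warp_group_thread_idx must be in [0, 128)")
    # Warps are consecutive 32-thread blocks; the lane is the position inside one.
    return divmod(warp_group_thread_idx, WARP_SIZE)


print(split_warp_group_thread_idx(37))  # (1, 5)
```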
write_tile
write_tile(self, c_reg_tile: LayoutTensor[MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=c_reg_tile.element_layout, layout_int_type=c_reg_tile.layout_int_type, linear_idx_type=c_reg_tile.linear_idx_type, masked=c_reg_tile.masked, alignment=c_reg_tile.alignment], coords: Tuple[Int, Int])
Write a single MMA tile from registers to global memory.
Args:
- c_reg_tile (LayoutTensor[MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=c_reg_tile.element_layout, layout_int_type=c_reg_tile.layout_int_type, linear_idx_type=c_reg_tile.linear_idx_type, masked=c_reg_tile.masked, alignment=c_reg_tile.alignment]): Register tile containing accumulator values.
- coords (Tuple[Int, Int]): Tile coordinates (row, column) in the destination matrix.
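This page does not say which output element each accumulator register holds. For reference, the PTX ISA documents the register layout of a wgmma .m64nNk16 accumulator tile (32-bit accumulators). The Python sketch below encodes that layout and checks that it covers the 64 x N tile exactly once.

Two things here are assumptions, not stated on this page:

- That RegisterToGMemWriter walks registers in this order.
- How coords offsets the tile in the destination matrix.

The layout also matches the num_m_frag_mat formula: each warp owns 16 rows, which form two 8-row fragments.

```python
WARP_SIZE = 32


def wgmma_accum_coord(thread_idx: int, i: int):
    """(row, col) of accumulator register i held by warp-group thread thread_idx.

    Follows the PTX ISA layout for the wgmma .m64nNk16 D fragment.
    """
    warp, lane = divmod(thread_idx, WARP_SIZE)
    # Each warp owns 16 rows; lane // 4 picks one of 8 rows, and every other
    # pair of registers sits 8 rows lower.
    row = warp * 16 + lane // 4 + 8 * ((i // 2) % 2)
    # Within each 8-column chunk a lane holds 2 adjacent columns.
    col = (lane % 4) * 2 + (i % 2) + 8 * (i // 4)
    return row, col


# 128 threads x c_frag_size registers should tile a 64 x N output exactly once.
N = 128
c_frag_size = 64 * N // 128
coords = {wgmma_accum_coord(t, i) for t in range(128) for i in range(c_frag_size)}
assert coords == {(r, c) for r in range(64) for c in range(N)}
print(wgmma_accum_coord(5, 2))  # (9, 2)
```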