Mojo function
write_output_fragments
write_output_fragments[c_type: DType, c_frag_size: Int, MMA_M: Int, MMA_N: Int, output_thread_layout: Layout, elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](c_reg_fragment: LayoutTensor[c_reg_fragment.dtype, c_reg_fragment.layout, c_reg_fragment.origin, address_space=c_reg_fragment.address_space, element_layout=c_reg_fragment.element_layout, layout_int_type=c_reg_fragment.layout_int_type, linear_idx_type=c_reg_fragment.linear_idx_type, masked=c_reg_fragment.masked, alignment=c_reg_fragment.alignment], c_gmem_fragment: LayoutTensor[c_gmem_fragment.dtype, c_gmem_fragment.layout, c_gmem_fragment.origin, address_space=c_gmem_fragment.address_space, element_layout=c_gmem_fragment.element_layout, layout_int_type=c_gmem_fragment.layout_int_type, linear_idx_type=c_gmem_fragment.linear_idx_type, masked=c_gmem_fragment.masked, alignment=c_gmem_fragment.alignment], warp_tile_m: Int, warp_tile_n: Int, M: Int, N: Int)
Write output fragments from registers to global memory with optional elementwise operations.
Parameters:
- c_type (DType): Data type for the output matrix C.
- c_frag_size (Int): Size of each output fragment.
- MMA_M (Int): Matrix multiply instruction M dimension.
- MMA_N (Int): Matrix multiply instruction N dimension.
- output_thread_layout (Layout): Thread layout for output operations.
- elementwise_lambda_fn (Optional): Optional elementwise operation to apply.
Args:
- c_reg_fragment (LayoutTensor): Register fragments containing computation results.
- c_gmem_fragment (LayoutTensor): Global memory fragment for output.
- warp_tile_m (Int): M coordinate of the warp tile.
- warp_tile_n (Int): N coordinate of the warp tile.
- M (Int): Total M dimension of the output matrix.
- N (Int): Total N dimension of the output matrix.
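The arguments above suggest a common epilogue pattern: each fragment element is mapped to a global (row, column) position, checked against M and N, and then either stored or handed to the optional elementwise function. The following is a minimal host-side model of that pattern in plain Python, not the Mojo implementation. The helper name `write_tile`, the row-major output buffer, the reading of `warp_tile_m` and `warp_tile_n` as tile indices, and the choice to call the epilogue instead of storing are all assumptions made for illustration.

```python
def write_tile(frag, out, tile_m, tile_n, M, N, epilogue=None):
    """Model of a bounds-checked tile write (hypothetical helper).

    frag     : 2D list holding one tile of results (stands in for registers)
    out      : flat row-major list of length M * N (stands in for global memory)
    tile_m/n : tile indices along M and N (assumed meaning of the coordinates)
    epilogue : optional callable taking ((row, col), value); assumed to
               replace the plain store when provided
    """
    rows = len(frag)
    cols = len(frag[0])
    for i in range(rows):
        for j in range(cols):
            r = tile_m * rows + i  # global row of this element
            c = tile_n * cols + j  # global column of this element
            if r >= M or c >= N:
                continue  # skip elements that fall outside the matrix
            if epilogue is not None:
                epilogue((r, c), frag[i][j])
            else:
                out[r * N + c] = frag[i][j]


# A 2x2 tile at tile index (1, 1) of a 3x3 matrix: only element (2, 2) is in bounds.
M, N = 3, 3
out = [0.0] * (M * N)
write_tile([[1.0, 2.0], [3.0, 4.0]], out, 1, 1, M, N)
```

In this model, passing an `epilogue` leaves `out` untouched and delivers each in-bounds value together with its coordinates, which mirrors the shape of the `elementwise_lambda_fn` signature (an index pair plus a value, returning nothing).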