Mojo function
fa4_scale_write_output
fa4_scale_write_output[qkv_type: DType, output_type: DType, config: FA4Config](local_row: UInt32, local_warp_idx: UInt32, warp_group_idx: UInt32, inv_row_sum: Float32, o_smem_arg: UnsafePointer[Scalar[output_type], MutAnyOrigin, address_space=AddressSpace.SHARED], o_tmem_arg: TMemTile[DType.float32, (config // 2), config.padded_depth], ragged_tma_store: RaggedTMA3DTile[output_type, config.swizzle_mode, (config // 2), config.depth], num_output_rows: Int32, out_head_idx: UInt32, out_row_idx: UInt32)