Mojo function
fa4_scale_write_output
fa4_scale_write_output[
    output_type: DType, //,
    config: FA4Config[config.qkv_dtype, rope_dtype=config.rope_dtype, scale_dtype=config.scale_dtype],
    output_swizzle_mode: TensorMapSwizzle = config.swizzle_mode
](
    local_row: UInt32,
    local_warp_idx: UInt32,
    warp_group_idx: UInt32,
    inv_row_sum: Float32,
    o_smem_arg: UnsafePointer[Scalar[output_type], MutAnyOrigin, address_space=AddressSpace.SHARED],
    o_tmem_arg: TMemTile[DType.float32, (config // 2), config.padded_ov_depth],
    ragged_tma_store: RaggedTMA3DTile[output_type, output_swizzle_mode, (config // 2), config.ov_depth],
    num_output_rows: Int32,
    out_head_idx: UInt32,
    out_row_idx: UInt32
)