Mojo function
scale_write_output
scale_write_output[
    BM: Int,
    BN: Int,
    depth: Int,
    padded_depth: Int,
    q_num_heads: Int,
    group: Int,
    decoding: Bool,
    accum_type: DType,
    output_type: DType,
    //,
    config: FA4Config
](
    local_row: UInt32,
    inv_row_sum: Scalar[accum_type],
    o_ptr_arg: UnsafePointer[Scalar[output_type]],
    o_smem: UnsafePointer[Scalar[output_type], address_space=AddressSpace(3)],
    o_tmem: TMemTile[accum_type, (config // 2), config.padded_depth],
    local_warp_group_idx: UInt32,
    position: MHAPosition[BM, BN, depth, padded_depth, q_num_heads, group, decoding],
    consumer_mbar: UnsafePointer[SharedMemBarrier, address_space=AddressSpace(3)]
)