Mojo function
fa4_scale_write_output
fa4_scale_write_output[
    qkv_type: DType,
    output_type: DType,
    config: FA4Config,
    output_swizzle_mode: TensorMapSwizzle = config.swizzle_mode,
    kv_depth: Int = config.depth,
    half_bm: Int = (config // 2),
    tmem_kv_depth: Int = config.padded_depth
](
    local_row: UInt32,
    local_warp_idx: UInt32,
    warp_group_idx: UInt32,
    inv_row_sum: Float32,
    o_smem_arg: UnsafePointer[Scalar[output_type], MutAnyOrigin, address_space=AddressSpace.SHARED],
    o_tmem_arg: TMemTile[DType.float32, half_bm, tmem_kv_depth],
    ragged_tma_store: RaggedTMA3DTile[output_type, output_swizzle_mode, half_bm, kv_depth],
    num_output_rows: Int32,
    out_head_idx: UInt32,
    out_row_idx: UInt32
)
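The function name and the `inv_row_sum: Float32` argument suggest the usual FlashAttention epilogue. In that epilogue, each output row accumulated in tensor memory (`o_tmem_arg`) is multiplied by the reciprocal of its softmax denominator before being written out, here presumably via `o_smem_arg` and `ragged_tma_store`. This page does not state that interpretation, so treat it as an assumption. The symbols $S$ (attention scores), $m_r$ (running row maximum), $\ell_r$ (row sum), and $\tilde{O}$ (unnormalized accumulator) are introduced only for illustration. Under that assumption, the per-row scaling is:

```latex
\ell_r = \sum_{j} \exp\left(S_{rj} - m_r\right), \qquad
\tilde{O}_{r,:} = \sum_{j} \exp\left(S_{rj} - m_r\right) V_{j,:}

O_{r,:} = \frac{1}{\ell_r}\,\tilde{O}_{r,:}
        = \texttt{inv\_row\_sum} \cdot \tilde{O}_{r,:}
```

Passing the reciprocal $1/\ell_r$ rather than $\ell_r$ presumably lets the kernel replace one division per output element (`kv_depth` of them per row) with a single division per row plus cheap multiplications.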