Mojo function
depth512_scale_write_output
depth512_scale_write_output[output_type: DType, qkv_dtype: DType, config: Depth512SM100Config[qkv_dtype]](tid: UInt32, m_row: UInt32, is_lower: Bool, inv_row_sum: Float32, smem: Depth512AttentionSMem[config], ragged_tma_store: RaggedTMA3DTile[output_type, config.swizzle_mode, BM=config.BM, BN=config.ov_depth, group=config.group if config.fuse_gqa else 1], num_output_rows: Int32, out_head_idx: UInt32, out_row_idx: UInt32)
Read O from TMEM, scale it by inv_row_sum, write the result to SMEM, then issue the TMA store.
split_o (d512): Two phases (O_lo, O_hi). Each thread processes ov_depth/4 physical TMEM columns per phase, and is_lower determines the output column base.
!split_o (d256): Single phase. Each thread processes MMA_M*ov_depth/256 physical TMEM columns, and all threads write to column base 0.
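The per-thread column counts described above can be checked with a small arithmetic model. This is a minimal sketch in plain Python, not kernel code: the helper name `tmem_cols_per_phase` and the sample values for `mma_m` and `ov_depth` are hypothetical and are not taken from `Depth512SM100Config`. The output column base is left out because the text only says that it depends on is_lower.

```python
def tmem_cols_per_phase(split_o: bool, mma_m: int, ov_depth: int) -> tuple[int, int]:
    """Return (number of phases, physical TMEM columns per thread per phase).

    Hypothetical helper that mirrors the two cases in the description;
    it does not model the output column base.
    """
    if split_o:
        # split_o (d512): two phases (O_lo, O_hi), ov_depth/4 columns in each.
        return 2, ov_depth // 4
    # !split_o (d256): one phase, MMA_M*ov_depth/256 columns.
    return 1, mma_m * ov_depth // 256


# Illustrative values only (assumed, not read from a real config).
print(tmem_cols_per_phase(True, 128, 512))
print(tmem_cols_per_phase(False, 128, 256))
```

Under these assumed values, the split case gives two phases of 512/4 columns and the single-phase case gives 128*256/256 columns per thread.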