Mojo function
fa4_lse_combine_write
def fa4_lse_combine_write[output_type: DType, //, config: FA4Config[config.qkv_dtype, rope_dtype=config.rope_dtype, scale_dtype=config.scale_dtype], wg_j_offset: Int, iters_per_wg: Int, output_swizzle_mode: TensorMapSwizzle = config.swizzle_mode](local_row: UInt32, local_warp_idx: UInt32, warp_group_idx: UInt32, final_scale_local: Float32, final_scale_peer: Float32, o_smem_arg: UnsafePointer[Scalar[output_type], MutAnyOrigin, address_space=AddressSpace.SHARED], own_o_tmem: TMemTile[DType.float32, config.BM, config.padded_ov_depth], peer_o_tmem: TMemTile[DType.float32, config.BM, config.padded_ov_depth], ragged_tma_store: RaggedTMA3DTile[output_type, output_swizzle_mode, BM=(config // config), BN=config.ov_depth, group=config.group if config.fuse_gqa else Int(1)], num_output_rows: Int32, out_head_idx: UInt32, out_row_idx: UInt32)
LSE-combine two TMEM_O fragments and TMA-store a depth-column slice.
1Q-only sibling of fa4_scale_write_output. Each warp group (WG) handles a
disjoint range j in [wg_j_offset, wg_j_offset + iters_per_wg) of swizzle-block
columns. For each j, the WG loads its own and the peer's TMEM_O
fragments, combines them in registers using per-row scales
(final_scale_local for its own fragment, final_scale_peer for the peer's),
writes the combined output to slot j of the shared o_smem_arg buffer, then
TMA-stores that slot to global memory (gmem). Both WGs target the same BM Q
rows but disjoint depth columns, so their smem and gmem regions never overlap.
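The combine and the column split described above can be modeled on the host in a few lines of plain Python. This is only a rough sketch, not the kernel's implementation: it assumes each fragment is an already-normalized partial attention output with its own log-sum-exp (LSE) value, and that the per-row scales are the usual LSE weights. How the real kernel derives final_scale_local and final_scale_peer is not stated here. The helper names partial_attention, lse_scales, and combine_write, and the block_cols parameter, are hypothetical.

```python
import math

def partial_attention(scores, values):
    """Softmax-weighted sum over one key split; returns (output row, LSE)."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    depth = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total
           for d in range(depth)]
    return out, m + math.log(total)

def lse_scales(lse_local, lse_peer):
    """Assumed per-row weights: exp(lse_x - lse_total) for each fragment."""
    m = max(lse_local, lse_peer)
    lse_total = m + math.log(math.exp(lse_local - m) + math.exp(lse_peer - m))
    return math.exp(lse_local - lse_total), math.exp(lse_peer - lse_total)

def combine_write(own, peer, scale_local, scale_peer, out,
                  wg_j_offset, iters_per_wg, block_cols):
    """Combine column blocks j in [wg_j_offset, wg_j_offset + iters_per_wg)."""
    for j in range(wg_j_offset, wg_j_offset + iters_per_wg):
        for col in range(j * block_cols, (j + 1) * block_cols):
            for row in range(len(own)):
                out[row][col] = (scale_local[row] * own[row][col]
                                 + scale_peer[row] * peer[row][col])

# One query row, four keys split into two halves, depth 4.
scores = [0.5, -1.0, 2.0, 0.25]
values = [[1.0, 2.0, 3.0, 4.0],
          [0.0, 1.0, 0.0, 1.0],
          [2.0, 2.0, 2.0, 2.0],
          [4.0, 3.0, 2.0, 1.0]]

own_row, lse_own = partial_attention(scores[:2], values[:2])
peer_row, lse_peer = partial_attention(scores[2:], values[2:])
s_local, s_peer = lse_scales(lse_own, lse_peer)

out = [[0.0] * 4]
# Two "warp groups" write disjoint column blocks of the same row.
combine_write([own_row], [peer_row], [s_local], [s_peer], out,
              wg_j_offset=0, iters_per_wg=1, block_cols=2)
combine_write([own_row], [peer_row], [s_local], [s_peer], out,
              wg_j_offset=1, iters_per_wg=1, block_cols=2)

# Attention over all four keys at once, for comparison.
reference, _ = partial_attention(scores, values)
```

Under these assumptions, out[0] matches reference up to floating-point rounding. Because the two calls cover disjoint j ranges, neither overwrites the other's columns, which mirrors the non-overlap property stated above.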
The caller must have already waited on both pipeline_o0 and
pipeline_o1 producer barriers (and issued tcgen05_fence_after())
before invoking this helper, so the TMEM fragments are visible.