IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo function

fa4_lse_combine_write

def fa4_lse_combine_write[output_type: DType, //, config: FA4Config[config.qkv_dtype, rope_dtype=config.rope_dtype, scale_dtype=config.scale_dtype], wg_j_offset: Int, iters_per_wg: Int, output_swizzle_mode: TensorMapSwizzle = config.swizzle_mode](local_row: UInt32, local_warp_idx: UInt32, warp_group_idx: UInt32, final_scale_local: Float32, final_scale_peer: Float32, o_smem_arg: UnsafePointer[Scalar[output_type], MutAnyOrigin, address_space=AddressSpace.SHARED], own_o_tmem: TMemTile[DType.float32, config.BM, config.padded_ov_depth], peer_o_tmem: TMemTile[DType.float32, config.BM, config.padded_ov_depth], ragged_tma_store: RaggedTMA3DTile[output_type, output_swizzle_mode, BM=(config // config), BN=config.ov_depth, group=config.group if config.fuse_gqa else Int(1)], num_output_rows: Int32, out_head_idx: UInt32, out_row_idx: UInt32)

LSE-combine two TMEM_O fragments and TMA-store a depth-column slice.
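The helper receives final_scale_local and final_scale_peer already computed, and this page does not say how the caller derives them. The sketch below assumes the standard log-sum-exp merge of two partial softmax results, where each TMEM_O fragment holds an unnormalized partial output. It is plain Python for one row and is not the kernel code. The names m_local, l_local, m_peer and l_peer are hypothetical: each partial's running row maximum and running row sum of exponentials.

```python
import math


def lse_combine_scales(m_local, l_local, m_peer, l_peer):
    """Per-row scales for merging two unnormalized partial softmax outputs.

    Assumption (not taken from the kernel): each partial keeps a row max m_*
    and a row sum l_* = sum(exp(logit - m_*)) over its own keys.
    """
    m = max(m_local, m_peer)            # global row max, for numerical stability
    a_local = math.exp(m_local - m)     # rescale each partial to the shared max
    a_peer = math.exp(m_peer - m)
    denom = a_local * l_local + a_peer * l_peer   # full softmax denominator
    return a_local / denom, a_peer / denom


def partial_softmax(logits, values):
    """Unnormalized partial output (m, l, o) for a scalar value column."""
    m = max(logits)
    w = [math.exp(x - m) for x in logits]
    return m, sum(w), sum(wi * vi for wi, vi in zip(w, values))


# Usage: split one row's keys into two halves, merge, and compare to a full softmax.
logits = [1.0, 2.0, 3.0, 0.5]
values = [1.0, 2.0, 3.0, 4.0]
m0, l0, o0 = partial_softmax(logits[:2], values[:2])
m1, l1, o1 = partial_softmax(logits[2:], values[2:])
scale_local, scale_peer = lse_combine_scales(m0, l0, m1, l1)
combined = scale_local * o0 + scale_peer * o1

full_m = max(logits)
full_w = [math.exp(x - full_m) for x in logits]
reference = sum(w * v for w, v in zip(full_w, values)) / sum(full_w)
```

With scales of this form, the combined row is simply `scale_local * own + scale_peer * peer`, which matches the per-row multiply-add this helper performs in registers.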

1Q-only sibling of fa4_scale_write_output. Each warp group (WG) handles a disjoint range of swizzle-block columns, j in [wg_j_offset, wg_j_offset + iters_per_wg). For each j, the WG loads both its own and the peer's TMEM_O fragments, combines them in registers using per-row scales (final_scale_local for its own fragment, final_scale_peer for the peer's), writes the combined output into slot j of the shared-memory buffer o_smem_arg, and then TMA-stores that slot to global memory. Both WGs cover the same BM Q rows but disjoint depth columns, so their shared-memory and global-memory regions never overlap.
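The column partitioning and per-row combine described above can be modeled in plain Python as follows. This is a sketch under stated assumptions, not the kernel. block_cols is a hypothetical stand-in for the width of one swizzle block, and the nested list out plays the role of both the o_smem_arg slot and the TMA-stored global output. Note that each WG passes its own fragment and scale first, so the two calls swap the operands.

```python
def wg_j_range(wg_j_offset: int, iters_per_wg: int) -> range:
    """Swizzle-block column indices j handled by one warp group."""
    return range(wg_j_offset, wg_j_offset + iters_per_wg)


def combine_block(own_o, peer_o, scale_local, scale_peer, j, block_cols):
    """Combine columns of block j for every row: own * s_local + peer * s_peer."""
    c0 = j * block_cols
    return [
        [own_o[r][c] * scale_local[r] + peer_o[r][c] * scale_peer[r]
         for c in range(c0, c0 + block_cols)]
        for r in range(len(own_o))
    ]


def wg_combine_write(own_o, peer_o, scale_local, scale_peer,
                     wg_j_offset, iters_per_wg, block_cols, out):
    """Model of one WG: stage each block (the smem slot j), then 'store' it."""
    for j in wg_j_range(wg_j_offset, iters_per_wg):
        staged = combine_block(own_o, peer_o, scale_local, scale_peer, j, block_cols)
        c0 = j * block_cols
        for r, row in enumerate(staged):      # stands in for the TMA store of slot j
            out[r][c0:c0 + block_cols] = row


# Usage: BM = 2 rows, depth split into 4 blocks of 2 columns, half per WG.
BM, n_blocks, block_cols = 2, 4, 2
depth = n_blocks * block_cols
o0 = [[float(r * depth + c) for c in range(depth)] for r in range(BM)]
o1 = [[1.0] * depth for _ in range(BM)]
s0 = [0.25, 0.5]   # per-row scales for O0
s1 = [0.75, 0.5]   # per-row scales for O1
out = [[None] * depth for _ in range(BM)]
half = n_blocks // 2
wg_combine_write(o0, o1, s0, s1, 0, half, block_cols, out)     # WG0: own = O0
wg_combine_write(o1, o0, s1, s0, half, half, block_cols, out)  # WG1: own = O1
```

Because the two j ranges are disjoint and each WG writes all BM rows for its columns, every output element is written exactly once and both WGs produce the same combined values.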

The caller must have already waited on both pipeline_o0 and pipeline_o1 producer barriers (and issued tcgen05_fence_after()) before invoking this helper, so the TMEM fragments are visible.
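The ordering requirement can be illustrated with a host-side analogy. In this sketch, threading.Event objects are hypothetical stand-ins for the pipeline_o0 and pipeline_o1 producer barriers, and the dictionary stands in for TMEM. Python has no counterpart to tcgen05_fence_after(); the point is only that the consumer waits on both barriers before reading either fragment.

```python
import threading

o0_ready = threading.Event()   # stand-in for the pipeline_o0 producer barrier
o1_ready = threading.Event()   # stand-in for the pipeline_o1 producer barrier
tmem = {}                      # stand-in for the two TMEM_O fragments
result = []


def producer(name, event, value):
    tmem[name] = value         # write the fragment first...
    event.set()                # ...then signal the barrier


def consumer():
    # Wait on BOTH barriers before touching either fragment.
    o0_ready.wait()
    o1_ready.wait()
    result.append(tmem["o0"] + tmem["o1"])


c = threading.Thread(target=consumer)
c.start()
p0 = threading.Thread(target=producer, args=("o0", o0_ready, 1.0))
p1 = threading.Thread(target=producer, args=("o1", o1_ready, 2.0))
p0.start()
p1.start()
for t in (p0, p1, c):
    t.join()
```

Waiting on only one barrier would let the consumer race with the other producer, which is the situation the precondition rules out.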