Mojo module
mla_decode_sm100_combine
MLA Decode Split-K Combine Kernel for SM100 (B200).
This kernel combines the partial outputs of a split-K attention computation. Each split computes attention over one portion of the KV cache. The combine kernel then merges these partial results, weighting each one by its log-sum-exp (LSE) value so that the rescaling stays numerically stable.
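The formulas below imply that each split emits two things: a partial output already normalized by its own split's softmax denominator, and that split's base-2 LSE. The following is a minimal pure-Python sketch of what one split produces, not the kernel's implementation. It assumes the scores have already been scaled into base 2 (for example, multiplied by log2(e) and the softmax scale). The function name `partial_attention` is hypothetical.

```python
import math

def partial_attention(scores_log2, values):
    """Attention over one KV split (hypothetical helper, for illustration).

    scores_log2: attention logits for this split, assumed already scaled to base 2.
    values: one value vector (list of floats) per key in the split.
    Returns (partial_output, lse), where partial_output is normalized within
    this split only and lse = log2(sum_j exp2(score_j)) over the split.
    """
    m = max(scores_log2)                       # subtract the max for stability
    weights = [2.0 ** (s - m) for s in scores_log2]
    denom = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / denom
           for d in range(dim)]
    lse = m + math.log2(denom)                 # undo the max shift in log space
    return out, lse
```

For a split with two equal scores of 0 and values `[0.0]` and `[2.0]`, this returns the average `[1.0]` and an LSE of `log2(2) = 1.0`.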
Algorithm:
- Load partial LSE values for all splits
- Compute the global LSE: global_lse = log2(sum_i exp2(lse_i - max_lse)) + max_lse, where max_lse = max_i lse_i
- Compute per-split scale factors: scale_i = exp2(lse_i - global_lse); these scales sum to 1 across splits
- Weighted sum: output = sum(scale_i * partial_output_i)
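The four steps above can be sketched in plain Python as follows. This is an illustrative reference under the assumption that all LSE values are in base 2 and each partial output is normalized within its split. The name `combine_splits` is hypothetical and does not correspond to a function in this module.

```python
import math

def combine_splits(lses, partial_outputs):
    """Merge split-K partial attention outputs (hypothetical reference helper).

    lses: base-2 log-sum-exp of each split.
    partial_outputs: per-split output vectors, each normalized within its split.
    Returns (output, global_lse).
    """
    max_lse = max(lses)                                  # step 1: max over splits
    global_lse = math.log2(sum(2.0 ** (l - max_lse) for l in lses)) + max_lse
    scales = [2.0 ** (l - global_lse) for l in lses]     # step 3: sum to 1
    dim = len(partial_outputs[0])
    output = [sum(s * o[d] for s, o in zip(scales, partial_outputs))
              for d in range(dim)]                       # step 4: weighted sum
    return output, global_lse
```

Subtracting `max_lse` before exponentiating keeps every `exp2` argument at or below zero, so the sum cannot overflow. An empty split with an LSE of `-inf` simply receives a scale of `0.0`, as long as at least one split is non-empty.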
Structs
Functions
- launch_mla_combine_kernel
- mla_combine_kernel
- mla_decode_combine_partial_outputs
- warp_reduce_max: Warp-level max reduction using a butterfly pattern.
- warp_reduce_sum: Warp-level sum reduction using a butterfly pattern.
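A butterfly reduction is the usual way to compute max_lse and the exp2 sum across a warp, and presumably it is how warp_reduce_max and warp_reduce_sum serve the combine step. The sketch below simulates the pattern on the CPU and assumes a 32-lane warp. At each step, every lane combines its value with the value held by lane `lane ^ offset`, for offsets 16, 8, 4, 2, 1. On a GPU each step would be one XOR shuffle. The names `butterfly_reduce`, `warp_reduce_max_sim`, and `warp_reduce_sum_sim` are hypothetical stand-ins, not this module's signatures.

```python
WARP_SIZE = 32  # assumption: 32 lanes per warp

def butterfly_reduce(lanes, op):
    """Simulate a butterfly (XOR-shuffle) reduction across one warp.

    After log2(WARP_SIZE) = 5 steps, every lane holds the full reduction,
    unlike a tree reduction where only lane 0 ends with the result.
    """
    assert len(lanes) == WARP_SIZE
    vals = list(lanes)
    offset = WARP_SIZE // 2
    while offset > 0:
        # All lanes read their partner's previous value at once, as a shuffle would.
        vals = [op(vals[i], vals[i ^ offset]) for i in range(WARP_SIZE)]
        offset //= 2
    return vals

def warp_reduce_max_sim(lanes):
    return butterfly_reduce(lanes, max)

def warp_reduce_sum_sim(lanes):
    return butterfly_reduce(lanes, lambda a, b: a + b)
```

Because every lane ends with the same result, each thread can compute its own `exp2(lse_i - max_lse)` term immediately, with no extra broadcast from lane 0.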