
Mojo module

mla_decode_sm100_combine

MLA Decode Split-K Combine Kernel for SM100 (B200).

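This kernel combines partial outputs from a split-K attention computation. Each split computes attention over a portion of the KV cache and produces an output normalized only over that portion, together with its LSE (log-sum-exp). The combine kernel merges these partial results by rescaling each one according to its LSE. Subtracting the maximum LSE before exponentiating keeps the computation numerically stable.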
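For illustration, here is a minimal Python sketch of what each split might hand to the combine kernel. It assumes a single query row whose attention scores are already in base-2 units (pre-multiplied by log2(e)), which matches the exp2/log2 formulas in the algorithm. The helper name `split_partials` and the even chunking of the KV sequence are hypothetical. They are not the kernel's actual interface.

```python
import math

def split_partials(scores, values, num_splits):
    """Hypothetical helper: attention for one query row, with the KV sequence
    cut into contiguous chunks of ceil(len/num_splits). Scores are assumed to
    be base-2 logits. Returns one (lse, partial_output) pair per non-empty chunk."""
    n = len(scores)
    chunk = -(-n // num_splits)  # ceiling division
    dim = len(values[0])
    results = []
    for start in range(0, n, chunk):
        s = scores[start:start + chunk]
        v = values[start:start + chunk]
        m = max(s)
        w = [2.0 ** (x - m) for x in s]  # max-shifted softmax weights
        z = sum(w)
        lse = math.log2(z) + m           # base-2 log-sum-exp of this split
        # Partial output is normalized by this split's sum only.
        out = [sum(wi * vi[d] for wi, vi in zip(w, v)) / z for d in range(dim)]
        results.append((lse, out))
    return results
```

Because each partial output is normalized only by its own split's sum, the partial outputs cannot simply be added together. The per-split LSE carries the normalization information the combine step needs.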

Algorithm:

  1. Load partial LSE values for all splits
  2. Compute global LSE: global_lse = log2(sum(exp2(lse_i - max_lse))) + max_lse, where max_lse = max(lse_i)
  3. Compute per-split scale factors: scale_i = exp2(lse_i - global_lse)
  4. Weighted sum: output = sum(scale_i * partial_output_i)
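The four steps can be sketched in plain Python for a single query row and head. This is an illustrative reference, not the kernel's implementation. The function name `combine_splits` is hypothetical, and the handling of the case where every split is empty (all LSEs equal to -inf) is an assumption added to avoid producing NaN.

```python
import math

def combine_splits(lse, partial_out):
    """Hypothetical reference combine for one query row/head.

    lse[i]:         base-2 log-sum-exp of split i's scores
    partial_out[i]: split i's output, normalized by its own softmax sum
    Returns (combined_output, global_lse).
    """
    dim = len(partial_out[0])
    # Step 1: the partial LSE values are given; take their max for stability.
    max_lse = max(lse)
    if max_lse == -math.inf:
        # Assumption: all splits were empty, so emit zeros instead of NaN.
        return [0.0] * dim, -math.inf
    # Step 2: global LSE, shifted by max_lse so exp2 cannot overflow.
    total = sum(2.0 ** (l - max_lse) for l in lse)
    global_lse = math.log2(total) + max_lse
    # Step 3: per-split scale factors; they sum to 1.
    scales = [2.0 ** (l - global_lse) for l in lse]
    # Step 4: weighted sum of the partial outputs.
    out = [0.0] * dim
    for s, o in zip(scales, partial_out):
        for d in range(dim):
            out[d] += s * o[d]
    return out, global_lse
```

For example, `combine_splits([1.0, 0.0], [[3.0, 0.0], [0.0, 3.0]])` uses scales of 2/3 and 1/3, so the combined output is approximately `[2.0, 1.0]` and the global LSE is approximately `log2(3)`. Because each scale equals the fraction of the total softmax mass held by its split, the result matches attention computed over the whole KV sequence at once.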

Structs

Functions
