Mojo module

mla_decode_combine

MLA Decode Split-K Combine Kernel for SM100 (B200).

This kernel combines the partial outputs of a split-K attention computation, in which each split computes attention over one portion of the KV cache. The combine kernel merges the partial results by rescaling each one with its LSE (Log-Sum-Exp) value, which keeps the merge numerically stable.

Algorithm:

  1. Load partial LSE values for all splits
  2. Compute global LSE: global_lse = log2(sum(exp2(lse_i - max_lse))) + max_lse, where max_lse is the maximum lse_i over all splits
  3. Compute per-split scale factors: scale_i = exp2(lse_i - global_lse)
  4. Weighted sum: output = sum(scale_i * partial_output_i)
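
The four steps above can be sketched on the host in plain Python for a single query row. This is a minimal reference sketch, not the kernel's implementation: the function names `attend` and `combine_splits` are hypothetical, and it assumes that each partial output is already normalized by its own split's softmax denominator and that each `lse_i` is the base-2 log of that denominator.

```python
import math

def attend(scores, values):
    """Hypothetical helper: attention over one KV range, base-2 softmax.

    Returns the normalized output vector and the base-2 LSE of the scores.
    """
    m = max(scores)
    weights = [2.0 ** (s - m) for s in scores]   # shifted for stability
    denom = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / denom
           for d in range(dim)]
    return out, math.log2(denom) + m

def combine_splits(partial_lse, partial_out):
    """Merge per-split outputs for one query row (assumed layout: one entry per split)."""
    # Step 1: the partial LSE values are already loaded in partial_lse.
    max_lse = max(partial_lse)
    # Step 2: global LSE, shifted by max_lse so exp2 never overflows.
    global_lse = math.log2(sum(2.0 ** (l - max_lse) for l in partial_lse)) + max_lse
    # Step 3: per-split scale factors; these sum to 1.
    scales = [2.0 ** (l - global_lse) for l in partial_lse]
    # Step 4: weighted sum of the partial outputs.
    dim = len(partial_out[0])
    out = [sum(s * o[d] for s, o in zip(scales, partial_out)) for d in range(dim)]
    return out, global_lse, scales

# Compare a two-way split against attention over the whole range.
scores = [0.5, -1.0, 2.0, 0.25]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0]]
full_out, full_lse = attend(scores, values)
o1, l1 = attend(scores[:2], values[:2])
o2, l2 = attend(scores[2:], values[2:])
out, lse, scales = combine_splits([l1, l2], [o1, o2])
```

Under these assumptions the combined result matches the unsplit computation up to floating-point rounding, because each scale factor equals that split's share of the global softmax denominator.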

Structs

Functions