
Mojo module

naive_fa_decode

Apple (Metal) split-K naive flash-attention DECODE kernels.

Targets Apple silicon GPUs (Metal) and handles decode only (one query token per sequence). The KV cache is paged and accessed through the MHAOperand contract; storage is BF16 and accumulation is FP32. These kernels generalize the single-head split-K decode prototype in .scratch/sdpa-decode/mha_decode_playground.mojo (the twoshot path) to MHA and GQA, using the in-tree MHAOperand / LayoutTensor / MHAMask contract.
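MHA and GQA differ only in which KV head each query head reads. The sketch below assumes the common GQA convention that each consecutive group of num_heads // num_kv_heads query heads shares one KV head; the source does not spell out the mapping, and kv_head_for and num_kv_heads are hypothetical names used only for illustration.

```python
def kv_head_for(head: int, num_heads: int, num_kv_heads: int) -> int:
    """Map a query head to the KV head it reads (assumed grouped GQA convention)."""
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be a multiple of num_kv_heads")
    group_size = num_heads // num_kv_heads  # query heads per KV head
    return head // group_size


# MHA is the special case num_kv_heads == num_heads (group size 1);
# MQA is num_kv_heads == 1 (every query head reads KV head 0).
```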

Two kernels:

  • naive_fa_decode_apple_core: producer. Grid (num_partitions, batch_size, num_heads), 1-D threadgroup. Each block owns one split of the KV range for one (batch, head) pair and writes per-partition partials (o_partial, m_partial, l_partial), computed with an online softmax over BN-wide KV tiles.
  • naive_fa_decode_apple_stitch: stitch. Grid (num_heads, batch_size), threadgroup size depth, so one thread handles each depth element. Combines the contiguous per-partition partials into the final output with a log-sum-exp (LSE) reduction.
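For each (batch, head), the two kernels together should produce ordinary softmax attention over the whole KV range. Below is a scalar Python reference model of that split-K scheme for a single query vector. It assumes the partials are an unnormalized output accumulator plus the running max and running sum of exponentials; the kernel's exact partial encoding may differ (for example, it could normalize per partition and store an LSE instead). All function names are illustrative, not part of the module.

```python
import math

BN = 16  # KV tile width (mirrors the module's comptime BN)


def partition_partials(q, keys, values, start, end, scale):
    """Producer sketch: online softmax over BN-wide tiles of keys[start:end].

    Returns (o, m, l): unnormalized output accumulator, running max score,
    and running sum of exp(score - m).
    """
    m, l = -math.inf, 0.0
    o = [0.0] * len(values[0])
    for t0 in range(start, end, BN):
        t1 = min(t0 + BN, end)
        scores = [scale * sum(a * b for a, b in zip(q, keys[j])) for j in range(t0, t1)]
        m_new = max(m, max(scores))
        alpha = math.exp(m - m_new)  # rescales earlier tiles; 0.0 on the first tile
        p = [math.exp(s - m_new) for s in scores]
        l = l * alpha + sum(p)
        o = [oi * alpha + sum(pj * values[t0 + j][d] for j, pj in enumerate(p))
             for d, oi in enumerate(o)]
        m = m_new
    return o, m, l


def stitch(partials):
    """Stitch sketch: log-sum-exp combine of per-partition (o, m, l) partials."""
    m_max = max(m for _, m, _ in partials)
    weights = [math.exp(m - m_max) for _, m, _ in partials]
    l_total = sum(w * l for w, (_, _, l) in zip(weights, partials))
    depth = len(partials[0][0])
    return [sum(w * o[d] for w, (o, _, _) in zip(weights, partials)) / l_total
            for d in range(depth)]


def reference_attention(q, keys, values, scale):
    """Plain softmax(scale * q.K^T) @ V for one query token."""
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    total = sum(p)
    return [sum(pj * v[d] for pj, v in zip(p, values)) / total
            for d in range(len(values[0]))]
```

Splitting the KV range into any number of non-empty partitions should reproduce reference_attention up to floating-point rounding; rescaling each partition by exp(m - m_max) before summing is what makes the combine independent of how the range was split.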

The host launcher naive_fa_decode_apple allocates the partials and enqueues both kernels; flash_attention_dispatch selects it for Apple decode behind the MODULAR_ENABLE_APPLE_NAIVE_FA_DECODE env flag (default off).

The online-softmax tiling and the LSE combine mirror the validated scratch prototype; only the I/O contract and the head/GQA indexing change.

Partial-buffer layout (partition-last / contiguous):

  • `ml_idx(b, head, split) = (b*num_heads + head)*num_partitions + split`
  • `o_idx(b, head, d, split) = ((b*num_heads + head)*depth + d)*num_partitions + split`
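The two index formulas transcribe directly into code. Because split varies fastest, the num_partitions partials that the stitch kernel reads for one (b, head, d) sit next to each other in memory. In this Python sketch the shape values are passed explicitly so the functions are self-contained.

```python
def ml_idx(b, head, split, num_heads, num_partitions):
    """Flat index into m_partial / l_partial (partition-last layout)."""
    return (b * num_heads + head) * num_partitions + split


def o_idx(b, head, d, split, num_heads, depth, num_partitions):
    """Flat index into o_partial (partition-last layout)."""
    return ((b * num_heads + head) * depth + d) * num_partitions + split


# Consecutive splits of the same (b, head, d) differ by exactly one element,
# and the index space is dense: every slot in [0, B*H*D*P) is used once.
```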

comptime values

BN

comptime BN = 16
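BN sets how many KV positions the producer scores per online-softmax step. The source does not specify how the KV range is divided among partitions or how a short final tile is handled; the sketch below assumes an even, ceil-divided split and a tail tile whose unused lanes are skipped or masked. split_range and tiles are hypothetical helpers.

```python
BN = 16  # KV tile width (mirrors the module's comptime BN)


def split_range(seq_len, num_partitions, split):
    """Hypothetical even split of [0, seq_len) into num_partitions chunks."""
    chunk = -(-seq_len // num_partitions)  # ceil(seq_len / num_partitions)
    start = min(split * chunk, seq_len)
    return start, min(start + chunk, seq_len)


def tiles(start, end, bn=BN):
    """Yield (tile_start, valid_len) for the BN-wide tiles covering [start, end)."""
    for t0 in range(start, end, bn):
        yield t0, min(bn, end - t0)  # last tile may be partial
```

Under this scheme, a large num_partitions relative to seq_len can leave trailing splits empty, so a kernel built this way would need to emit neutral partials (m at NEG_INF, l = 0) for them.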

NAIVE_FA_DECODE_APPLE_MAX_HEAD_DIM

comptime NAIVE_FA_DECODE_APPLE_MAX_HEAD_DIM = 256

NEG_INF

comptime NEG_INF = SIMD(-3.0000000000000001E+38)
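The source does not say why NEG_INF is a large finite value rather than true negative infinity. A likely reason (an assumption) is that online softmax repeatedly computes exp(m_old - m_new): if both are true -inf because every score seen so far was masked, IEEE arithmetic gives -inf - (-inf) = NaN, which then contaminates the accumulators. A finite sentinel keeps that difference at 0, and once a real score sets the max, exp(NEG_INF - m) underflows to 0. A Python illustration (Python floats are float64):

```python
import math
import struct

NEG_INF = -3.0e38  # same value as the module's NEG_INF

# The sentinel is also representable in FP32 (max finite is about 3.40e38).
neg_inf_fp32 = struct.unpack("<f", struct.pack("<f", NEG_INF))[0]

# Rescale factor when the old and new running max are both "minus infinity":
alpha_true_inf = math.exp(-math.inf - (-math.inf))  # exp(nan) -> nan
alpha_sentinel = math.exp(NEG_INF - NEG_INF)        # exp(0.0) -> 1.0, finite

# Once a real score (here 2.5) becomes the max, a sentinel-scored entry gets no weight.
masked_weight = math.exp(NEG_INF - 2.5)             # underflows to 0.0
```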

Functions