Mojo module

naive_fa_decode

Apple (Metal) split-K naive flash-attention DECODE kernels.

Apple silicon GPU (Metal), decode-only (one query token per sequence), paged KV cache via the MHAOperand contract, BF16 storage / FP32 accumulation.

Warp-centric producer: one simdgroup (32 lanes) owns one split of the KV range for one (batch, head). Lane L owns the contiguous head-dim chunk [L*EPL, L*EPL+EPL) where EPL = head_dim // WARP_SIZE; the query and running output stay in registers, Q.K^T is reduced across lanes with one air.simd_sum per key, and P.V is reduction-free. The inner loop has no barrier() and no threadgroup memory, the two levers Apple silicon is most sensitive to.
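The lane ownership described above can be modelled on the host. This is a rough sketch in plain Python, not the Mojo kernel: the helper names (`lane_chunk`, `qk_score`, `pv_accumulate`) are hypothetical, and the ordinary Python `sum` over lanes only stands in for the cross-lane `simd_sum`.

```python
WARP_SIZE = 32  # lanes per simdgroup, as stated above


def lane_chunk(lane, head_dim):
    """Half-open head-dim range [lane*EPL, lane*EPL + EPL) owned by one lane."""
    epl = head_dim // WARP_SIZE  # elements per lane
    return lane * epl, lane * epl + epl


def qk_score(q, k):
    """One Q.K score: per-lane partial dot products, then one cross-lane sum."""
    head_dim = len(q)
    partials = []
    for lane in range(WARP_SIZE):
        lo, hi = lane_chunk(lane, head_dim)
        partials.append(sum(q[d] * k[d] for d in range(lo, hi)))
    return sum(partials)  # stand-in for the simd_sum reduction


def pv_accumulate(out, p, v):
    """P.V update: each lane writes only its own chunk, so no reduction is needed."""
    head_dim = len(out)
    for lane in range(WARP_SIZE):
        lo, hi = lane_chunk(lane, head_dim)
        for d in range(lo, hi):
            out[d] += p * v[d]
    return out
```

Because the chunks are disjoint and cover the whole head dimension when `head_dim` is a multiple of `WARP_SIZE`, the score needs exactly one reduction per key while the output update needs none.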

Two kernels:

  • naive_fa_decode_apple_core: producer. Grid (num_partitions, batch_size, num_heads), block = WARP_SIZE (one simdgroup). Each block writes per-partition partials (o_partial, m_partial, l_partial) via online softmax over BN-wide KV tiles.
  • naive_fa_decode_apple_stitch: stitch. Grid (num_heads, batch_size), block = depth, with one thread per depth element. Combines the contiguous per-partition partials into the final output with a log-sum-exp (LSE) reduction.
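The producer/stitch split can be illustrated with a small host-side model. This is a minimal sketch in plain Python under stated assumptions, not the module's implementation: the function names are hypothetical, the partial output is assumed to be stored unnormalized, the softmax scale factor is omitted, and every partition is assumed to be non-empty.

```python
import math

BN = 16  # KV tile width, matching the comptime value on this page


def producer(q, keys, values, start, end):
    """Online softmax over one KV split; returns (o_partial, m_partial, l_partial)."""
    m_run = float("-inf")      # running max of the scores seen so far
    l_run = 0.0                # running sum of exp(score - m_run)
    o_run = [0.0] * len(q)     # running unnormalized output (assumption)
    for tile in range(start, end, BN):
        idx = range(tile, min(tile + BN, end))
        scores = [sum(a * b for a, b in zip(q, keys[j])) for j in idx]
        m_new = max(m_run, max(scores))
        scale = math.exp(m_run - m_new)  # 0.0 on the first tile
        l_run *= scale
        o_run = [x * scale for x in o_run]
        for j, s in zip(idx, scores):
            p = math.exp(s - m_new)
            l_run += p
            o_run = [x + p * v for x, v in zip(o_run, values[j])]
        m_run = m_new
    return o_run, m_run, l_run


def stitch(partials):
    """Combine per-partition partials by rescaling each to the global max."""
    m_global = max(m for _, m, _ in partials)
    weights = [math.exp(m - m_global) for _, m, _ in partials]
    l_global = sum(w * l for w, (_, _, l) in zip(weights, partials))
    depth = len(partials[0][0])
    return [
        sum(w * o[d] for w, (o, _, _) in zip(weights, partials)) / l_global
        for d in range(depth)
    ]


def attention_reference(q, keys, values):
    """Single-pass softmax attention for one query, used only as a check."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    mx = max(scores)
    ps = [math.exp(s - mx) for s in scores]
    z = sum(ps)
    return [sum(p * v[d] for p, v in zip(ps, values)) / z for d in range(len(q))]
```

Stitching the partials from any partitioning of the KV range should agree with `attention_reference` up to floating-point rounding, which is what makes the split-K decomposition safe.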

The host launcher naive_fa_decode_apple allocates the partials and enqueues both kernels; flash_attention_dispatch selects it for Apple decode by default (set MODULAR_ENABLE_APPLE_NAIVE_FA_DECODE=0 to opt out). The launcher dispatches the runtime depth to a compile-time Depth specialization over the multiples of WARP_SIZE up to NAIVE_FA_DECODE_APPLE_MAX_HEAD_DIM. The dispatcher only routes here when depth % WARP_SIZE == 0 and depth <= NAIVE_FA_DECODE_APPLE_MAX_HEAD_DIM; otherwise mha_gpu_naive runs.
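The routing condition can be written out as a small predicate. This is an illustrative Python sketch only: the function names are hypothetical, `WARP_SIZE = 32` is taken from the "32 lanes" statement above, and the environment-variable opt-out is not modelled.

```python
WARP_SIZE = 32
NAIVE_FA_DECODE_APPLE_MAX_HEAD_DIM = 256  # comptime value listed on this page


def routes_to_naive_fa_decode(depth):
    """Condition as stated: depth is a multiple of WARP_SIZE and within the cap."""
    return depth % WARP_SIZE == 0 and depth <= NAIVE_FA_DECODE_APPLE_MAX_HEAD_DIM


def specialized_depths():
    """Multiples of WARP_SIZE up to the cap, i.e. the candidate Depth values."""
    return list(range(WARP_SIZE, NAIVE_FA_DECODE_APPLE_MAX_HEAD_DIM + 1, WARP_SIZE))
```

With these values, a head dimension of 128 qualifies, while 80 (not a multiple of 32) and 288 (above the cap) would fall back to mha_gpu_naive.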

Partial-buffer layout (partition-last / contiguous):

  • ml_idx(b, head, split) = (b*num_heads + head)*num_partitions + split
  • o_idx(b, head, d, split) = ((b*num_heads + head)*depth + d)*num_partitions + split
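The partition-last layout can be checked with a direct transcription of the two index formulas. This is a sketch in Python, with the shape values passed as explicit (hypothetical) parameters rather than captured from a kernel context.

```python
def ml_idx(b, head, split, num_heads, num_partitions):
    """Flat index into m_partial / l_partial; split is the fastest-varying axis."""
    return (b * num_heads + head) * num_partitions + split


def o_idx(b, head, d, split, num_heads, depth, num_partitions):
    """Flat index into o_partial; split is the fastest-varying axis."""
    return ((b * num_heads + head) * depth + d) * num_partitions + split
```

For a fixed (b, head, d), consecutive splits map to consecutive addresses, so each stitch thread reads its num_partitions partials from one contiguous run.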

comptime values

BN

comptime BN = 16

NAIVE_FA_DECODE_APPLE_MAX_HEAD_DIM

comptime NAIVE_FA_DECODE_APPLE_MAX_HEAD_DIM = 256

NEG_INF

comptime NEG_INF = SIMD(-3.0000000000000001E+38)

Functions