Mojo module
naive_fa_decode
Apple (Metal) split-K naive flash-attention DECODE kernels.
Apple silicon GPU (Metal), decode-only (one query token per sequence), paged KV
cache via the MHAOperand contract, BF16 storage / FP32 accumulation.
Warp-centric producer: one simdgroup (32 lanes) owns one split of the KV range
for one (batch, head). Lane L owns the contiguous head-dim chunk
[L*EPL, L*EPL+EPL) where EPL = head_dim // WARP_SIZE; the query and running
output stay in registers, Q.K^T is reduced across lanes with one air.simd_sum
per key, and P.V is reduction-free. The inner loop has no barrier() and no
threadgroup memory, the two levers Apple silicon is most sensitive to.
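The lane ownership described above can be modelled on the host. The following is a minimal Python sketch, not the Mojo kernel: `lane_chunk`, `qk_dot`, and `pv_accumulate` are hypothetical names, and the plain Python `sum` over per-lane partials only stands in for the cross-lane `air.simd_sum`.

```python
WARP_SIZE = 32  # one simdgroup (32 lanes), per the description above


def lane_chunk(lane, head_dim):
    """Return the [start, stop) head-dim range owned by `lane`."""
    epl = head_dim // WARP_SIZE  # EPL: elements per lane
    return lane * epl, lane * epl + epl


def qk_dot(q, k):
    """Q.K^T for one key: per-lane partial dots, then one cross-lane sum."""
    head_dim = len(q)
    partials = []
    for lane in range(WARP_SIZE):
        lo, hi = lane_chunk(lane, head_dim)
        partials.append(sum(q[d] * k[d] for d in range(lo, hi)))
    # Stand-in for the single simd_sum reduction per key.
    return sum(partials)


def pv_accumulate(out, p, v):
    """P.V for one key: each lane touches only its own chunk, no reduction."""
    head_dim = len(out)
    for lane in range(WARP_SIZE):
        lo, hi = lane_chunk(lane, head_dim)
        for d in range(lo, hi):
            out[d] += p * v[d]
    return out
```

Because every lane writes a disjoint slice of the output, the P.V update needs no communication between lanes; only the score computation does.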
Two kernels:
- naive_fa_decode_apple_core: producer. Grid (num_partitions, batch_size, num_heads), block = WARP_SIZE (one simdgroup). Each block writes per-partition partials (o_partial, m_partial, l_partial) via online softmax over BN-wide KV tiles.
- naive_fa_decode_apple_stitch: stitch. Grid (num_heads, batch_size), block = depth. One thread per depth element; combines the contiguous per-partition partials into the final output with a log-sum-exp (LSE) reduction.
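The producer/stitch split can be checked numerically with a small host-side model. The sketch below is plain Python rather than the Mojo kernels, and it rests on several assumptions not stated on this page: the partial output `o` is stored unnormalized, the KV range is divided into equal ceil-sized splits, every split is non-empty, and the running max starts at negative infinity. All function names are hypothetical.

```python
import math

BN = 16  # KV tile width, matching the comptime value documented on this page


def core_partial(q, keys, values, start, stop, scale):
    """Online softmax over one KV split; returns (o, m, l).

    o: unnormalized output, m: running max score, l: sum of exp(score - m).
    """
    o = [0.0] * len(q)
    m = -math.inf
    l = 0.0
    for tile in range(start, stop, BN):
        idx = range(tile, min(tile + BN, stop))
        scores = [scale * sum(a * b for a, b in zip(q, keys[j])) for j in idx]
        m_new = max(m, max(scores))
        alpha = math.exp(m - m_new)  # rescales everything accumulated so far
        o = [x * alpha for x in o]
        l *= alpha
        for j, s in zip(idx, scores):
            p = math.exp(s - m_new)
            l += p
            o = [x + p * v for x, v in zip(o, values[j])]
        m = m_new
    return o, m, l


def stitch(partials):
    """Combine per-partition (o, m, l) partials with a log-sum-exp reduction."""
    m_all = max(m for _, m, _ in partials)
    weights = [math.exp(m - m_all) for _, m, _ in partials]
    l_all = sum(w * l for w, (_, _, l) in zip(weights, partials))
    depth = len(partials[0][0])
    return [
        sum(w * o[d] for w, (o, _, _) in zip(weights, partials)) / l_all
        for d in range(depth)
    ]


def split_k_decode(q, keys, values, num_partitions, scale):
    """Run the producer on each split, then stitch (assumed equal splits)."""
    n = len(keys)
    per = -(-n // num_partitions)  # ceil division
    partials = [
        core_partial(q, keys, values, s, min(s + per, n), scale)
        for s in range(0, n, per)
    ]
    return stitch(partials)


def reference(q, keys, values, scale):
    """Single-pass softmax attention for one query, used as ground truth."""
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    total = sum(p)
    return [
        sum(pi * v[d] for pi, v in zip(p, values)) / total
        for d in range(len(q))
    ]
```

Up to floating-point rounding, `split_k_decode` should agree with `reference` for any partition count, which is the property that lets the KV range be processed by independent simdgroups.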
The host launcher naive_fa_decode_apple allocates the partials and enqueues
both kernels; flash_attention_dispatch selects it for Apple decode by default
(set MODULAR_ENABLE_APPLE_NAIVE_FA_DECODE=0 to opt out). The launcher
dispatches the runtime depth to a compile-time Depth specialization over the
multiples of WARP_SIZE up to NAIVE_FA_DECODE_APPLE_MAX_HEAD_DIM; the
dispatcher only routes here when depth % WARP_SIZE == 0 and
depth <= NAIVE_FA_DECODE_APPLE_MAX_HEAD_DIM, otherwise mha_gpu_naive runs.
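The routing condition can be written out as a small predicate. This is a Python sketch of the rule as stated above, not the dispatcher's actual code: `WARP_SIZE = 32` is taken from the "32 lanes" description, the function names are hypothetical, and a positive `depth` is assumed.

```python
WARP_SIZE = 32  # simdgroup width, per the description above
NAIVE_FA_DECODE_APPLE_MAX_HEAD_DIM = 256  # comptime value from this page


def specialized_depths():
    """Multiples of WARP_SIZE up to the maximum supported head dim."""
    return list(
        range(WARP_SIZE, NAIVE_FA_DECODE_APPLE_MAX_HEAD_DIM + 1, WARP_SIZE)
    )


def routes_to_naive_fa_decode(depth):
    """Mirror of the stated condition (positive depth assumed)."""
    return (
        depth % WARP_SIZE == 0
        and depth <= NAIVE_FA_DECODE_APPLE_MAX_HEAD_DIM
    )
```

Under these assumptions, head dims such as 64 and 128 take this path, while 80 (not a multiple of 32) and 512 (above the maximum) fall back to mha_gpu_naive.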
Partial-buffer layout (partition-last / contiguous):
- `ml_idx(b, head, split) = (b*num_heads + head)*num_partitions + split`
- `o_idx(b, head, d, split) = ((b*num_heads + head)*depth + d)*num_partitions + split`
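To make the partition-last layout concrete, here is a short Python sketch of the two index functions. It assumes the final term of `o_idx` is `+ split`, mirroring `ml_idx`; the extra shape parameters are passed explicitly here only because the sketch has no surrounding kernel state.

```python
def ml_idx(b, head, split, num_heads, num_partitions):
    """Flat index into m_partial / l_partial (partition is the fastest axis)."""
    return (b * num_heads + head) * num_partitions + split


def o_idx(b, head, d, split, num_heads, depth, num_partitions):
    """Flat index into o_partial (partition is the fastest axis)."""
    return ((b * num_heads + head) * depth + d) * num_partitions + split
```

With this layout, the `num_partitions` partials for a fixed `(b, head, d)` sit next to each other in memory, so the stitch thread for one depth element reads a contiguous run.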
comptime values
BN
comptime BN = 16
NAIVE_FA_DECODE_APPLE_MAX_HEAD_DIM
comptime NAIVE_FA_DECODE_APPLE_MAX_HEAD_DIM = 256
NEG_INF
comptime NEG_INF = SIMD(-3.0000000000000001E+38)
Functions
- naive_fa_decode_apple: Host launcher for the Apple split-K decode attention pair (decode-only).
- naive_fa_decode_apple_core: Warp-centric split-K online-softmax producer for Apple decode attention.
- naive_fa_decode_apple_stitch