Mojo module

mha_prefill

Unified gfx950 MHA prefill kernel.

Handles BF16 and FP8 inputs, any mask, depth ∈ {64, 128, 256, 512}, with or without a sink. Consolidates the two older kernels (the pipelined mha_prefill and the non-pipelined mha_prefill_gfx950) into one. The result matches amd/mha.mojo from main but is built on our own version of KVBuffer.

Recipe:

  • Double-buffered LDS slots (slot ping-pong) to overlap DRAM→LDS with MMA.
  • V registers loaded after QK (V-load-later) to reduce peak VGPR use.
  • Per-iteration: wait K → mma_qk → prefetch next tile → split softmax → wait V → mma_pv.
  • Softmax picks between prescaled / deferred-scale (_fma) / plain based on prescale_q + depth + mask.apply_log2e_after_mask.
  • Non-causal masks check each iteration for FULL_MASK tiles and skip them.
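
The arithmetic behind the recipe can be sketched in plain Python for a single query row. This is a reference sketch only, not the kernel: `attention_row_tiled`, `attention_row_reference`, and all parameter names are hypothetical, and the LDS slot ping-pong, V-load-later scheduling, and MMA instructions are omitted. It assumes the prescaled variant folds `scale * log2(e)` into Q so that the softmax can use base-2 exponentials, and it models a FULL_MASK tile as one whose positions are all masked out.

```python
import math

LOG2E = math.log2(math.e)  # exp(x) == 2 ** (x * LOG2E)


def attention_row_tiled(q, keys, values, mask, tile=2, scale=1.0):
    """One query row of attention, streamed over K/V tiles (hypothetical helper)."""
    # Assumed "prescaled" path: fold scale * log2(e) into q once, up front.
    q2 = [x * scale * LOG2E for x in q]
    row_max, row_sum = -math.inf, 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(keys), tile):
        idx = list(range(start, min(start + tile, len(keys))))
        # Fully masked tile: nothing to accumulate, so skip it.
        if not any(mask[j] for j in idx):
            continue
        # QK step: scores in the log2 domain, masked positions set to -inf.
        scores = [
            sum(a * b for a, b in zip(q2, keys[j])) if mask[j] else -math.inf
            for j in idx
        ]
        # Online softmax: rescale the running sum and output by the max shift.
        new_max = max(row_max, max(scores))
        correction = 2.0 ** (row_max - new_max)  # 0.0 on the first live tile
        probs = [2.0 ** (s - new_max) for s in scores]
        row_sum = row_sum * correction + sum(probs)
        acc = [a * correction for a in acc]
        # PV step: accumulate the probability-weighted value rows.
        for p, j in zip(probs, idx):
            acc = [a + p * v for a, v in zip(acc, values[j])]
        row_max = new_max
    if row_sum == 0.0:  # every position was masked
        return acc
    return [a / row_sum for a in acc]


def attention_row_reference(q, keys, values, mask, scale=1.0):
    """Non-streaming softmax attention over the unmasked positions."""
    live = [j for j in range(len(keys)) if mask[j]]
    scores = [scale * sum(a * b for a, b in zip(q, keys[j])) for j in live]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    depth = len(values[0])
    return [
        sum(w * values[j][d] for w, j in zip(weights, live)) / total
        for d in range(depth)
    ]
```

The streamed version agrees with the reference up to floating-point rounding, and skipping a fully masked tile does not change the result because such a tile would only contribute zero-weight terms.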

Sink is handled entirely in Attention.__init__ (rowmax/rowsum are pre-filled from sink_weights); the kernel body needs no sink-specific code.
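
Why pre-filling the running statistics is enough can be seen in a small Python sketch. It assumes the sink contributes one extra logit per row to the softmax denominator and has no value vector. The exact pre-fill performed in Attention.__init__ is not shown here, and the real kernel may work in the log2 domain. Both function names are hypothetical.

```python
import math


def softmax_with_sink_init(scores, sink_logit):
    """Softmax whose running max/sum start from the sink (hypothetical helper)."""
    # Pre-fill: the sink is the only term seen so far, so its weight is exp(0).
    row_max = sink_logit
    row_sum = 1.0
    # The loop body is the ordinary online-softmax update; it has no sink code.
    for s in scores:
        new_max = max(row_max, s)
        row_sum = row_sum * math.exp(row_max - new_max) + math.exp(s - new_max)
        row_max = new_max
    return [math.exp(s - row_max) / row_sum for s in scores]


def softmax_with_sink_reference(scores, sink_logit):
    """Direct form: the sink logit is an extra term in the denominator."""
    top = max(scores + [sink_logit])
    denom = math.exp(sink_logit - top) + sum(math.exp(s - top) for s in scores)
    return [math.exp(s - top) / denom for s in scores]
```

Both forms give the same probabilities, which sum to less than 1 because the sink absorbs the remainder.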

Functions