Mojo module

mha_prefill

Unified gfx950 MHA prefill kernel.

Handles BF16 and FP8 inputs, any mask, and depth ∈ {64, 128, 256, 512}, with or without sink. Consolidates the two older kernels (the pipelined mha_prefill and the non-pipelined mha_prefill_gfx950) into one, matching amd/mha.mojo from main but built on our own version of KVBuffer.

Recipe:

  • Double-buffered LDS slots (slot ping-pong) to overlap DRAM→LDS with MMA.
  • V registers loaded after QK (V-load-later) to reduce peak VGPR use.
  • Per iteration: wait K → mma_qk → prefetch next tile → split softmax → wait V → mma_pv.
  • Softmax picks between the prescaled, deferred-scale (_fma), and plain variants based on prescale_q, depth, and mask.apply_log2e_after_mask.
  • Non-causal masks check each iteration for FULL_MASK tiles and skip them.
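The recipe above can be illustrated with a small reference model. The following is a minimal sketch in plain Python, not the Mojo kernel: it processes a single query row, and the names `prefill_row`, `slots`, and `visible` are hypothetical. It assumes the deferred-scale variant folds `scale * log2(e)` into the exponent so that the softmax runs in base 2, and it models the two LDS slots as a two-element list. Everything runs sequentially here, so it shows only the ordering of the steps, not the actual overlap of loads with MMA.

```python
import math

LOG2E = math.log2(math.e)

def prefill_row(q, keys, values, scale, tile=2, visible=lambda j: True):
    """Tiled online-softmax attention for one query row (illustrative only)."""
    starts = list(range(0, len(keys), tile))
    c = scale * LOG2E                       # deferred scale, applied inside the exponent
    row_max, row_sum = -math.inf, 0.0
    acc = [0.0] * len(values[0])
    slots = [keys[0:tile], None]            # stand-in for the two ping-pong LDS slots

    for i, start in enumerate(starts):
        k_tile = slots[i % 2]               # "wait K": the tile loaded one step earlier
        cols = [start + t for t in range(len(k_tile))]
        fully_masked = not any(visible(j) for j in cols)

        if not fully_masked:                # "mma_qk": raw, unscaled scores
            s = [sum(a * b for a, b in zip(q, k)) for k in k_tile]

        if i + 1 < len(starts):             # "prefetch next tile" into the other slot
            nxt = starts[i + 1]
            slots[(i + 1) % 2] = keys[nxt:nxt + tile]

        if fully_masked:                    # FULL_MASK tile: skip softmax and PV
            continue

        # Softmax update in base 2: x * c - new_max is one fused multiply-add.
        live = [(j, x) for j, x in zip(cols, s) if visible(j)]
        new_max = max(row_max, max(x * c for _, x in live))
        corr = 2.0 ** (row_max - new_max)   # 0.0 on the first live tile
        p = [(j, 2.0 ** (x * c - new_max)) for j, x in live]
        row_sum = row_sum * corr + sum(w for _, w in p)

        # "wait V" then "mma_pv": V is only touched after the scores exist.
        acc = [a * corr for a in acc]
        for j, w in p:
            acc = [a + w * v for a, v in zip(acc, values[j])]
        row_max = new_max

    # If every tile was masked there is nothing to normalize.
    return [a / row_sum for a in acc] if row_sum > 0.0 else acc
```

Calling `prefill_row(q, keys, values, scale, tile=2, visible=lambda j: j not in (2, 3))` skips the second tile entirely and should agree with a direct softmax over the remaining columns up to floating-point rounding.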

Sink is handled entirely in Attention.__init__ (rowmax/rowsum are pre-filled from sink_weights); the kernel body needs no sink-specific code.
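To see why pre-filling the running statistics is enough, consider the sketch below. It is a plain-Python illustration, not the Mojo code, and it rests on an assumption: the sink behaves as one extra logit per row that adds to the softmax denominator but has no value vector. The names `attend_with_sink` and `attend_reference` are hypothetical, and natural `exp` is used for brevity where the kernel may work in base 2. Starting from `row_max = sink` and `row_sum = 1.0` (since `exp(sink - sink) == 1`), the loop body is the ordinary online-softmax update with no sink-specific branch.

```python
import math

def attend_with_sink(scores, values, sink, tile=2):
    """Online softmax whose state is pre-filled from the sink logit."""
    row_max, row_sum = sink, 1.0            # pre-filled state: the sink's own term
    acc = [0.0] * len(values[0])            # the sink contributes no value
    for start in range(0, len(scores), tile):
        s = scores[start:start + tile]
        v = values[start:start + tile]
        new_max = max(row_max, max(s))
        corr = math.exp(row_max - new_max)  # rescales the sink term like any other
        p = [math.exp(x - new_max) for x in s]
        row_sum = row_sum * corr + sum(p)
        acc = [a * corr + sum(w * vec[d] for w, vec in zip(p, v))
               for d, a in enumerate(acc)]
        row_max = new_max
    return [a / row_sum for a in acc]

def attend_reference(scores, values, sink):
    """Direct form: the sink appears only in the denominator."""
    weights = [math.exp(x) for x in scores]
    denom = math.exp(sink) + sum(weights)
    return [sum(w * vec[d] for w, vec in zip(weights, values)) / denom
            for d in range(len(values[0]))]
```

Under that assumption the two functions should agree up to rounding for any tile size, which is the property that lets the per-tile loop stay unaware of the sink.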

Functions
