Mojo struct
MhaPrefillV2
struct MhaPrefillV2[config: MhaConfigV2]
8-warp MHA forward kernel parameterized by MhaConfigV2.
Each block runs config.num_warps wave64 warps that share K/V
SMEM via cooperative DMA. Warp w owns Q rows [w * q_block_size, (w + 1) * q_block_size) of the block's stripe and carries its own
register-resident attention state.
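The row ownership described above can be sketched as plain index arithmetic. This is an illustrative Python sketch, not the kernel's code: the helper name `warp_q_rows`, the example value `q_block_size=32`, and the assumption that a block's stripe starts at `block_idx * BM` are all hypothetical.

```python
def warp_q_rows(block_idx: int, num_warps: int, q_block_size: int):
    """Return one half-open (start, end) Q-row range per warp of a block."""
    bm = num_warps * q_block_size   # rows covered by the whole block (BM)
    base = block_idx * bm           # assumed start of this block's stripe
    return [
        (base + w * q_block_size, base + (w + 1) * q_block_size)
        for w in range(num_warps)
    ]


# Hypothetical shape: 8 warps, 32 Q rows per warp, second block of the grid.
rows = warp_q_rows(block_idx=1, num_warps=8, q_block_size=32)
for w, (start, end) in enumerate(rows):
    print(f"warp {w}: Q rows [{start}, {end})")
```

The ranges are disjoint and contiguous, so each Q row of the stripe is handled by exactly one warp.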
Parameters
- config (MhaConfigV2): Shape configuration (MhaConfigV2).
Implemented traits
comptime members
BM
comptime BM = (MhaPrefillV2[config].NUM_WARPS * MhaPrefillV2[config].Q_BLOCK_SIZE)
D_FRAG_PER_LANE
comptime D_FRAG_PER_LANE = ((MhaPrefillV2[config].DEPTH * MhaPrefillV2[config].Q_BLOCK_SIZE) // 64)
DEPTH
comptime DEPTH = config.depth
k_swizzle
comptime k_swizzle = Optional(Swizzle(1, 0, 4))
k_swizzle2
comptime k_swizzle2 = Optional(Swizzle(1, 1, 4))
KTileLoader
comptime KTileLoader = SubTileLoaderLDS[config.dtype, MhaPrefillV2[config].k_swizzle, MhaPrefillV2[config].k_swizzle2]
KV_BLOCK
comptime KV_BLOCK = config.kv_block
NUM_HEADS
comptime NUM_HEADS = config.num_heads
NUM_KV_HEADS
comptime NUM_KV_HEADS = config.num_kv_heads
NUM_THREADS
comptime NUM_THREADS = (MhaPrefillV2[config].NUM_WARPS * 64)
NUM_WARPS
comptime NUM_WARPS = config.num_warps
prescale_q
comptime prescale_q = not config.dtype.is_float8().__bool__()
Q_BLOCK_SIZE
comptime Q_BLOCK_SIZE = config.q_block_size
RESCALE_THRESHOLD
comptime RESCALE_THRESHOLD = config.rescale_threshold
v_swizzle
comptime v_swizzle = Optional(None)
VTileLoader
comptime VTileLoader = SubTileLoaderLDS_st_8x32[config.dtype, MhaPrefillV2[config].KV_BLOCK, MhaPrefillV2[config].DEPTH, 64 if config.dtype.is_float8() else 32, MhaPrefillV2[config].NUM_THREADS]
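To make the derived members concrete, the following Python sketch evaluates the same formulas for one example configuration. `ShapeConfig` is a hypothetical stand-in for `MhaConfigV2` that carries only the fields used here, and the values (8 warps, 32 rows per warp, depth 128, a non-FP8 dtype) are assumptions chosen for illustration.

```python
from dataclasses import dataclass

WAVE_SIZE = 64  # lanes per wave64 warp; the literal 64 in the comptime formulas


@dataclass(frozen=True)
class ShapeConfig:
    """Hypothetical stand-in for MhaConfigV2 (only the fields used below)."""
    num_warps: int
    q_block_size: int
    depth: int
    is_float8: bool


def derived(cfg: ShapeConfig) -> dict:
    """Mirror the comptime formulas for BM, NUM_THREADS, D_FRAG_PER_LANE, prescale_q."""
    return {
        "BM": cfg.num_warps * cfg.q_block_size,
        "NUM_THREADS": cfg.num_warps * WAVE_SIZE,
        "D_FRAG_PER_LANE": (cfg.depth * cfg.q_block_size) // WAVE_SIZE,
        "prescale_q": not cfg.is_float8,
    }


consts = derived(ShapeConfig(num_warps=8, q_block_size=32, depth=128, is_float8=False))
print(consts)
```

For this assumed configuration each block covers 256 Q rows with 512 threads, and `prescale_q` is true because the dtype is not FP8.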
Methods
load_q
static def load_q[layout: TensorLayout](q_warp_2d: TileTensor[config.dtype, layout, address_space=q_warp_2d.address_space, linear_idx_type=q_warp_2d.linear_idx_type, element_size=q_warp_2d.element_size]) -> TileTensor[config.dtype, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]
Loads the warp's Q sub-tile from gmem into a row_l register tile via RegTileLoader.
- BF16 (d=128, MMA_K=16): 8 K-tiles × 1 buffer_load_bf16x8 per lane per K-tile = 8 loads × 16 B each. The per-lane fragment is 8 BF16 = 16 B, which fits in one buffer_load.
- FP8 (d=128, MMA_K=64): 2 K-tiles, but each base tile per lane is 32 FP8 = 32 B, which exceeds the 16-B buffer_load_lds maximum. Each K-tile load is therefore split into 2 × 16-element halves (16 B each) that target the first and second half of the destination cell.
Returns:
TileTensor[config.dtype, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]
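The BF16 and FP8 load counts quoted for `load_q` follow from simple byte arithmetic, sketched here in Python. The helper `q_load_plan` is hypothetical; the per-lane element counts (8 for BF16, 32 for FP8) and the 16-byte per-load limit are taken from the description above rather than computed from the kernel.

```python
import math

MAX_LOAD_BYTES = 16  # per-lane load limit quoted in the load_q description


def q_load_plan(depth: int, mma_k: int, elems_per_lane: int, bytes_per_elem: int) -> dict:
    """Count K-tiles and per-lane loads needed to bring one warp's Q tile into registers."""
    k_tiles = depth // mma_k                        # K-tiles along the depth axis
    frag_bytes = elems_per_lane * bytes_per_elem    # per-lane fragment size per K-tile
    loads_per_tile = math.ceil(frag_bytes / MAX_LOAD_BYTES)
    return {
        "k_tiles": k_tiles,
        "frag_bytes": frag_bytes,
        "loads_per_tile": loads_per_tile,
        "total_loads": k_tiles * loads_per_tile,
    }


bf16 = q_load_plan(depth=128, mma_k=16, elems_per_lane=8, bytes_per_elem=2)
fp8 = q_load_plan(depth=128, mma_k=64, elems_per_lane=32, bytes_per_elem=1)
print("BF16:", bf16)
print("FP8: ", fp8)
```

BF16 needs one 16-byte load for each of its 8 K-tiles, while FP8 has only 2 K-tiles but must issue two 16-byte loads for each.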
run
static def run[k_t: MHAOperand, v_t: MHAOperand, mask_t: MHAMask, q_dtype: DType, output_dtype: DType, q_layout: TensorLayout, o_layout: TensorLayout, ragged: Bool = False, sink: Bool = False](q: TileTensor[q_dtype, q_layout, ImmutAnyOrigin], k: k_t, v: v_t, o: TileTensor[output_dtype, o_layout, MutAnyOrigin], mask_functor: mask_t, scale: Float32, num_keys: Int, start_pos: Int, sink_weights_ptr: UnsafePointer[Scalar[q_dtype], ImmutAnyOrigin])
Multi-block 8-warp MHA forward (inference-only).
Grid: (NUM_HEADS, ceildiv(seq_len, BM), batch). Each block
owns one (batch, head, BM-tile) slice; the 8 warps within
split the BM-tile's Q rows.
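As a rough illustration of the grid shape, the Python sketch below computes the launch dimensions and the number of valid Q rows in each BM-tile. The helper names and the example sizes (32 heads, `seq_len=1000`, `BM=256`, batch 4) are assumptions; how the kernel itself treats the partial last tile is not described here.

```python
def ceildiv(a: int, b: int) -> int:
    """Integer ceiling division for positive b."""
    return -(-a // b)


def launch_grid(num_heads: int, seq_len: int, bm: int, batch: int) -> tuple:
    """Grid layout quoted above: (NUM_HEADS, ceildiv(seq_len, BM), batch)."""
    return (num_heads, ceildiv(seq_len, bm), batch)


def valid_rows(tile_idx: int, seq_len: int, bm: int) -> int:
    """Number of in-range Q rows in a given BM-tile (the last tile may be partial)."""
    return max(0, min(bm, seq_len - tile_idx * bm))


grid = launch_grid(num_heads=32, seq_len=1000, bm=256, batch=4)
print(grid)
print([valid_rows(t, seq_len=1000, bm=256) for t in range(grid[1])])
```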
Expected layouts / shapes:
- q, o: (batch, seq_len, NUM_HEADS, DEPTH) row-major TileTensor. o's dtype matches config.output_dtype: BF16 for the production dispatcher (which holds a BF16 output buffer) or FP32 if the caller wants the unnormalized accumulator.
- k, v: any MHAOperand whose block_paged_tile[KV_BLOCK] returns (KV_BLOCK, DEPTH) tiles per (batch, t*KV_BLOCK, kv_head, 0). LayoutTensorMHAOperand for contiguous test / bench buffers; KVCacheMHAOperand for paged production caches (page_size >= KV_BLOCK = 64).
batch and seq_len / num_keys may be dynamic;
NUM_HEADS, NUM_KV_HEADS, DEPTH must be static.
NUM_HEADS must be a multiple of NUM_KV_HEADS
(GROUP = NUM_HEADS // NUM_KV_HEADS).
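The divisibility requirement can be illustrated with a small Python sketch. The mapping `kv_head = head_idx // GROUP` is the usual grouped-query convention and is an assumption here; this page states only how GROUP is defined, not how the kernel indexes K/V heads.

```python
def kv_head_for(head_idx: int, num_heads: int, num_kv_heads: int) -> int:
    """Map a Q head to its K/V head, assuming consecutive Q heads share one K/V head."""
    if num_heads % num_kv_heads != 0:
        raise ValueError("NUM_HEADS must be a multiple of NUM_KV_HEADS")
    group = num_heads // num_kv_heads   # GROUP from the text above
    return head_idx // group            # assumed mapping, not confirmed by this page


# Hypothetical shape: 32 Q heads sharing 8 K/V heads, so GROUP = 4.
mapping = [kv_head_for(h, num_heads=32, num_kv_heads=8) for h in range(32)]
print(mapping)
```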
Args:
- q (TileTensor[q_dtype, q_layout, ImmutAnyOrigin]): Q tile tensor.
- k (k_t): K operand (MHAOperand).
- v (v_t): V operand (MHAOperand).
- o (TileTensor[output_dtype, o_layout, MutAnyOrigin]): Output tile tensor (config.output_dtype, same shape as q).
- mask_functor (mask_t): Per-tile mask predicate (causal, sliding-window, etc.). Evaluated inside the QK-softmax cluster; identity for unmasked attention.
- scale (Float32): Softmax scale (typically 1 / sqrt(DEPTH)).
- num_keys (Int): Runtime length of the K/V sequence.
- start_pos (Int): Position of the first Q row in the global sequence; non-zero for prefill chunks of a longer generation. Used by the mask functor to compute the causal cutoff.
- sink_weights_ptr (UnsafePointer[Scalar[q_dtype], ImmutAnyOrigin]): Per-q-head attention-sink scalar weights. Read only when the comptime sink parameter is True; the non-sink path comptime-elides the load, so callers may pass UnsafePointer[...].unsafe_dangling() when sink=False. Indexed by head_idx once per block at init time, cast to FP32, multiplied by log2e to land in the kernel's log2-units rowmax, and seeded into max_vec / max_vec_prev / norm_vec so the hot loop stays sink-agnostic.
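The sink seeding can be illustrated with a scalar online-softmax sketch in Python that works in log2 units, using the identity exp(x) = 2^(x * log2e). This is a sketch under stated assumptions, not the kernel's code: the initial normalizer value of 1.0 and the choice not to multiply the sink weight by `scale` are inferred from the description, which mentions only the log2e factor.

```python
import math

LOG2E = math.log2(math.e)


def running_softmax_state(scores, scale, sink=None):
    """Return (row_max, norm) in log2 units after one pass over a row of QK scores."""
    if sink is None:
        row_max, norm = -math.inf, 0.0
    else:
        # Assumed seeding: the sink acts like one extra logit that is already
        # folded into the state before the loop starts.
        row_max, norm = sink * LOG2E, 1.0
    for s in scores:
        x = s * scale * LOG2E                     # score in log2 units
        new_max = max(row_max, x)
        # Rescale the old normalizer to the new max, then add the new term.
        norm = norm * 2.0 ** (row_max - new_max) + 2.0 ** (x - new_max)
        row_max = new_max
    return row_max, norm


scores, scale, sink = [1.0, 2.0, 3.0], 0.5, 0.25
row_max, norm = running_softmax_state(scores, scale, sink)

# Reference denominator in natural units: the sink is one extra exp() term.
reference = sum(math.exp(s * scale) for s in scores) + math.exp(sink)
print(norm * 2.0 ** row_max, reference)
```

Because the sink only changes the initial state, the loop body is identical with and without it, which is the sense in which the hot loop stays sink-agnostic.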
ragged_kernel
static def ragged_kernel[k_t: MHAOperand, v_t: MHAOperand, mask_t: MHAMask, qkv_dtype: DType, output_dtype: DType, cross_attention: Bool = False, sink: Bool = False](q_ptr: UnsafePointer[Scalar[qkv_dtype], ImmutAnyOrigin], k: k_t, v: v_t, output_ptr: UnsafePointer[Scalar[output_dtype], MutAnyOrigin], mask_functor: mask_t, scale: Float32, input_row_offsets_ptr: UnsafePointer[UInt32, ImmutAnyOrigin], kv_input_row_offsets_ptr: UnsafePointer[UInt32, ImmutAnyOrigin], sink_weights_ptr: UnsafePointer[Scalar[qkv_dtype], ImmutAnyOrigin])
Ragged-batch GPU kernel entry: per-sequence setup + run.
The non-ragged equivalent is run itself (which takes
already-sliced per-batch TileTensors). For ragged, this
wrapper does the per-block ragged setup so the launcher can
pass a single packed Q pointer + input_row_offsets.
- cross_attention=False (default): self-attention, where the K/V length equals the Q length plus any cached prefix. num_keys is derived as start_pos + seq_len. kv_input_row_offsets_ptr is unused (the caller may pass any well-typed stub).
- cross_attention=True: encoder-decoder style. K/V lengths come from kv_input_row_offsets_ptr, independent of the Q-side offsets. Mirrors the FA2 contract at mha.mojo:1755-1762.
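The per-sequence setup for the two modes can be sketched in Python as follows. The helper `ragged_lengths` is hypothetical, and the offsets layout (a prefix-sum array of length batch + 1, so sequence i spans `offsets[i]` to `offsets[i + 1]`) is an assumption about the packed format; where `start_pos` comes from in the real kernel is not shown on this page, so it is passed in explicitly.

```python
def ragged_lengths(input_row_offsets, seq_idx, start_pos=0, kv_input_row_offsets=None):
    """Return (q_start, seq_len, num_keys) for one sequence of a packed batch."""
    q_start = input_row_offsets[seq_idx]
    seq_len = input_row_offsets[seq_idx + 1] - q_start
    if kv_input_row_offsets is None:
        # Self-attention: keys cover the cached prefix plus the new Q rows.
        num_keys = start_pos + seq_len
    else:
        # Cross-attention: the K/V length is independent of the Q-side offsets.
        num_keys = kv_input_row_offsets[seq_idx + 1] - kv_input_row_offsets[seq_idx]
    return q_start, seq_len, num_keys


q_offsets = [0, 5, 12, 20]    # three packed sequences of lengths 5, 7, 8
kv_offsets = [0, 3, 9, 10]    # encoder-side lengths 3, 6, 1

self_attn = ragged_lengths(q_offsets, seq_idx=1, start_pos=10)
cross_attn = ragged_lengths(q_offsets, seq_idx=1, kv_input_row_offsets=kv_offsets)
print(self_attn, cross_attn)
```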