Mojo struct

MhaPrefillV2

struct MhaPrefillV2[config: MhaConfigV2]

8-warp MHA forward kernel parameterized by MhaConfigV2.

Each block runs config.num_warps wave64 warps that share K/V SMEM via cooperative DMA. Warp w owns Q rows [w * q_block_size, (w + 1) * q_block_size) of the block's stripe and carries its own register-resident attention state.
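The ownership rule can be sketched in host-side Python. This is an illustration, not kernel code. It assumes the block's stripe starts at `tile_idx * BM` (with `BM = num_warps * q_block_size`, as in the comptime members) and that rows past `seq_len` are simply clipped. The helper name `warp_q_rows` is hypothetical.

```python
def warp_q_rows(tile_idx: int, warp: int, num_warps: int, q_block_size: int, seq_len: int):
    """Return the [start, end) global Q rows owned by `warp` in BM-tile `tile_idx`.

    Assumption: the block's stripe starts at tile_idx * BM, BM = num_warps * q_block_size.
    Rows past seq_len are clipped, so a warp in a short last tile may own an empty range.
    """
    if not 0 <= warp < num_warps:
        raise ValueError("warp index out of range")
    bm = num_warps * q_block_size
    stripe_start = tile_idx * bm
    start = stripe_start + warp * q_block_size
    end = start + q_block_size
    return min(start, seq_len), min(end, seq_len)


if __name__ == "__main__":
    # Illustrative config: 8 warps x 32 rows per warp -> BM = 256.
    for w in range(8):
        print(w, warp_q_rows(tile_idx=1, warp=w, num_warps=8, q_block_size=32, seq_len=400))
```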

Parameters

  • config (MhaConfigV2): Shape and tuning configuration (dtype, depth, head counts, block sizes, warp count, rescale threshold).

Implemented traits

AnyType, ImplicitlyDeletable

comptime members

BM

comptime BM = (MhaPrefillV2[config].NUM_WARPS * MhaPrefillV2[config].Q_BLOCK_SIZE)

D_FRAG_PER_LANE

comptime D_FRAG_PER_LANE = ((MhaPrefillV2[config].DEPTH * MhaPrefillV2[config].Q_BLOCK_SIZE) // 64)

DEPTH

comptime DEPTH = config.depth

k_swizzle

comptime k_swizzle = Optional(Swizzle(1, 0, 4))

k_swizzle2

comptime k_swizzle2 = Optional(Swizzle(1, 1, 4))

KTileLoader

comptime KTileLoader = SubTileLoaderLDS[config.dtype, MhaPrefillV2[config].k_swizzle, MhaPrefillV2[config].k_swizzle2]

KV_BLOCK

comptime KV_BLOCK = config.kv_block

NUM_HEADS

comptime NUM_HEADS = config.num_heads

NUM_KV_HEADS

comptime NUM_KV_HEADS = config.num_kv_heads

NUM_THREADS

comptime NUM_THREADS = (MhaPrefillV2[config].NUM_WARPS * 64)

NUM_WARPS

comptime NUM_WARPS = config.num_warps

prescale_q

comptime prescale_q = not config.dtype.is_float8().__bool__()

Q_BLOCK_SIZE

comptime Q_BLOCK_SIZE = config.q_block_size

RESCALE_THRESHOLD

comptime RESCALE_THRESHOLD = config.rescale_threshold

v_swizzle

comptime v_swizzle = Optional(None)

VTileLoader

comptime VTileLoader = SubTileLoaderLDS_st_8x32[config.dtype, MhaPrefillV2[config].KV_BLOCK, MhaPrefillV2[config].DEPTH, 64 if config.dtype.is_float8() else 32, MhaPrefillV2[config].NUM_THREADS]
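The derived members are plain products of `config` fields. The Python sketch below evaluates them for an assumed configuration. `ConfigSketch` is a hypothetical stand-in for `MhaConfigV2`. The sample values (d=128, 8 warps, KV_BLOCK=64) follow figures quoted elsewhere on this page, while `q_block_size=32` is an assumption.

```python
from dataclasses import dataclass

WAVE_SIZE = 64  # wave64: 64 lanes per warp


@dataclass(frozen=True)
class ConfigSketch:
    """Hypothetical stand-in for MhaConfigV2 holding only the fields used below."""
    depth: int
    q_block_size: int
    num_warps: int
    kv_block: int
    is_float8: bool


def derived(cfg: ConfigSketch) -> dict:
    """Evaluate the comptime formulas listed above for one configuration."""
    return {
        "BM": cfg.num_warps * cfg.q_block_size,
        "NUM_THREADS": cfg.num_warps * WAVE_SIZE,
        # (DEPTH * Q_BLOCK_SIZE) // 64: a warp's Q_BLOCK_SIZE x DEPTH tile spread over 64 lanes.
        "D_FRAG_PER_LANE": (cfg.depth * cfg.q_block_size) // WAVE_SIZE,
        "KV_BLOCK": cfg.kv_block,
        "prescale_q": not cfg.is_float8,
        # Mirrors the literal 4th VTileLoader argument; its meaning is not documented here.
        "vtile_param4": 64 if cfg.is_float8 else 32,
    }


if __name__ == "__main__":
    print(derived(ConfigSketch(depth=128, q_block_size=32, num_warps=8, kv_block=64, is_float8=False)))
```

With those values each block covers BM = 256 Q rows with NUM_THREADS = 512 threads.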

Methods

load_q

static def load_q[layout: TensorLayout](q_warp_2d: TileTensor[config.dtype, layout, address_space=q_warp_2d.address_space, linear_idx_type=q_warp_2d.linear_idx_type, element_size=q_warp_2d.element_size]) -> TileTensor[config.dtype, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]

Loads the warp's Q sub-tile from gmem into a row_l register tile via RegTileLoader.

BF16 (d=128, MMA_K=16): 8 K-tiles × 1 buffer_load_bf16x8 per lane per K-tile = 8 loads × 16 B each. The per-lane fragment of 8 BF16 values (16 B) fits in a single buffer_load.

FP8 (d=128, MMA_K=64): 2 K-tiles, but each lane's share of a base tile is 32 FP8 values (32 B), which exceeds the 16-B buffer_load_lds maximum. load_q therefore splits each K-tile load into 2 × 16-element halves (16 B each) that target the first and second halves of the destination cell.
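The load counts above follow from simple byte arithmetic. The sketch below reproduces them using the page's own numbers. `q_load_plan` and `MAX_LOAD_BYTES` are hypothetical names. The per-lane fragment sizes (8 BF16, 32 FP8) are taken from the text rather than derived from an MMA shape.

```python
import math

MAX_LOAD_BYTES = 16  # widest per-lane load quoted above


def q_load_plan(depth: int, mma_k: int, elems_per_lane: int, elem_bytes: int):
    """Return (k_tiles, loads_per_k_tile, bytes_per_load) for one lane.

    elems_per_lane is the lane's Q fragment for one K-tile, as quoted in the text
    (8 for BF16 with MMA_K=16, 32 for FP8 with MMA_K=64).
    """
    k_tiles = depth // mma_k
    frag_bytes = elems_per_lane * elem_bytes
    loads_per_k_tile = math.ceil(frag_bytes / MAX_LOAD_BYTES)  # split oversized fragments
    bytes_per_load = frag_bytes // loads_per_k_tile
    return k_tiles, loads_per_k_tile, bytes_per_load


if __name__ == "__main__":
    for name, args in {"bf16": (128, 16, 8, 2), "fp8": (128, 64, 32, 1)}.items():
        k_tiles, per_tile, nbytes = q_load_plan(*args)
        print(name, k_tiles * per_tile, "loads of", nbytes, "B")
```

For BF16 this gives 8 K-tiles × 1 load = 8 loads of 16 B. For FP8 it gives 2 K-tiles × 2 loads = 4 loads of 16 B, each holding 16 FP8 elements.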

Returns:

TileTensor[config.dtype, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]

run

static def run[k_t: MHAOperand, v_t: MHAOperand, mask_t: MHAMask, q_dtype: DType, output_dtype: DType, q_layout: TensorLayout, o_layout: TensorLayout, ragged: Bool = False, sink: Bool = False](q: TileTensor[q_dtype, q_layout, ImmutAnyOrigin], k: k_t, v: v_t, o: TileTensor[output_dtype, o_layout, MutAnyOrigin], mask_functor: mask_t, scale: Float32, num_keys: Int, start_pos: Int, sink_weights_ptr: UnsafePointer[Scalar[q_dtype], ImmutAnyOrigin])

Multi-block 8-warp MHA forward (inference-only).

Grid: (NUM_HEADS, ceildiv(seq_len, BM), batch). Each block owns one (batch, head, BM-tile) slice, and its 8 warps split that tile's Q rows.
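A sketch of the launch geometry follows. It assumes grid axis x indexes heads, y indexes BM-tiles, and z indexes the batch, matching the order in which the tuple is written. `launch_grid` and `block_slice` are hypothetical helpers, not part of the kernel API.

```python
def ceildiv(a: int, b: int) -> int:
    """Integer ceiling division for non-negative a and positive b."""
    return -(-a // b)


def launch_grid(num_heads: int, seq_len: int, batch: int, bm: int):
    """Grid shape (x, y, z) = (NUM_HEADS, ceildiv(seq_len, BM), batch)."""
    return (num_heads, ceildiv(seq_len, bm), batch)


def block_slice(block_idx, bm: int, seq_len: int):
    """Map an assumed (head, tile, batch) block index to (batch, head, row_start, row_end)."""
    head, tile, batch = block_idx
    row_start = tile * bm
    row_end = min(row_start + bm, seq_len)  # the last tile may be partial
    return batch, head, row_start, row_end


if __name__ == "__main__":
    print(launch_grid(num_heads=32, seq_len=1000, batch=2, bm=256))
    print(block_slice((5, 3, 1), bm=256, seq_len=1000))
```

With 32 heads, seq_len = 1000, batch = 2, and BM = 256, the grid is (32, 4, 2). The last tile covers only rows 768 to 999.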

Expected layouts / shapes:

  • q, o: (batch, seq_len, NUM_HEADS, DEPTH) row-major TileTensor. o's dtype matches config.output_dtype: BF16 for the production dispatcher (which holds a BF16 output buffer), or FP32 if the caller wants the unnormalized accumulator.
  • k, v: any MHAOperand whose block_paged_tile[KV_BLOCK] returns a (KV_BLOCK, DEPTH) tile for each coordinate (batch, t*KV_BLOCK, kv_head, 0). Use LayoutTensorMHAOperand for contiguous test/benchmark buffers and KVCacheMHAOperand for paged production caches (page_size >= KV_BLOCK = 64).

batch and seq_len / num_keys may be dynamic; NUM_HEADS, NUM_KV_HEADS, DEPTH must be static. NUM_HEADS must be a multiple of NUM_KV_HEADS (GROUP = NUM_HEADS // NUM_KV_HEADS).
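The GROUP relationship and the K/V tile coordinates can be sketched as follows. The contiguous head grouping (`kv_head = head // GROUP`) is the conventional grouped-query-attention mapping and is an assumption here, as are the helper names. The coordinate tuple mirrors the `(batch, t*KV_BLOCK, kv_head, 0)` form quoted above.

```python
def kv_head_for(q_head: int, num_heads: int, num_kv_heads: int) -> int:
    """Assumed mapping: q heads [g*GROUP, (g+1)*GROUP) share kv head g."""
    if num_heads % num_kv_heads != 0:
        raise ValueError("NUM_HEADS must be a multiple of NUM_KV_HEADS")
    group = num_heads // num_kv_heads
    return q_head // group


def kv_tile_coords(batch: int, kv_head: int, num_keys: int, kv_block: int = 64):
    """One (batch, t*KV_BLOCK, kv_head, 0) coordinate per KV tile covering num_keys rows."""
    num_tiles = -(-num_keys // kv_block)  # ceildiv
    return [(batch, t * kv_block, kv_head, 0) for t in range(num_tiles)]


if __name__ == "__main__":
    print(kv_head_for(13, num_heads=32, num_kv_heads=8))
    print(kv_tile_coords(batch=0, kv_head=2, num_keys=130))
```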

Args:

  • q (TileTensor[q_dtype, q_layout, ImmutAnyOrigin]): Q tile tensor.
  • k (k_t): K operand (MHAOperand).
  • v (v_t): V operand (MHAOperand).
  • o (TileTensor[output_dtype, o_layout, MutAnyOrigin]): Output tile tensor (config.output_dtype, same shape as q).
  • mask_functor (mask_t): Per-tile mask predicate (causal, sliding-window, etc.). Evaluated inside the QK→softmax cluster; pass an identity mask for unmasked attention.
  • scale (Float32): Softmax scale (typically 1 / sqrt(DEPTH)).
  • num_keys (Int): Runtime length of the K/V sequence.
  • start_pos (Int): Position of the first Q row in the global sequence. It is non-zero when this call processes a later prefill chunk of a longer sequence. The mask functor uses it to compute the causal cutoff.
  • sink_weights_ptr (UnsafePointer[Scalar[q_dtype], ImmutAnyOrigin]): Per-q-head attention-sink scalar weights. Read only when the comptime sink parameter is True. The non-sink path elides the load at compile time, so callers may pass UnsafePointer[...].unsafe_dangling() when sink=False. The weight is indexed by head_idx once per block at init time, cast to FP32, and multiplied by log2e to match the kernel's log2-unit row max. It is then seeded into max_vec / max_vec_prev / norm_vec so the hot loop stays sink-agnostic.
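The sink seeding can be checked against a direct softmax. In the sketch below, `reference_with_sink` treats the sink as one extra logit that adds to the denominator but contributes no value. `online_with_seeded_sink` runs a blocked online softmax in log2 units, with the running max starting at `sink * log2e` and the normalizer starting at 1. Whether this matches the kernel's exact state update is an assumption. Both function names are hypothetical, and the scores are assumed to be already multiplied by `scale`.

```python
import math

LOG2E = 1.0 / math.log(2.0)  # log2(e)


def reference_with_sink(scores, values, sink):
    """Softmax over [scores..., sink]; the sink adds to the denominator only."""
    m = max(max(scores), sink)
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights) + math.exp(sink - m)
    depth = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / denom for d in range(depth)]


def online_with_seeded_sink(scores, values, sink, kv_block=2):
    """Blocked online softmax in log2 units with max/norm seeded from the sink (hypothetical)."""
    row_max = sink * LOG2E           # seeded running max, log2 units
    norm = 1.0                       # exp2(sink*log2e - row_max) == 1
    acc = [0.0] * len(values[0])     # the sink contributes nothing to the accumulator
    for start in range(0, len(scores), kv_block):
        blk = [s * LOG2E for s in scores[start:start + kv_block]]
        new_max = max(row_max, max(blk))
        correction = 2.0 ** (row_max - new_max)  # rescale earlier state to the new max
        p = [2.0 ** (s - new_max) for s in blk]
        norm = norm * correction + sum(p)
        acc = [a * correction for a in acc]
        for pj, v in zip(p, values[start:start + kv_block]):
            acc = [a + pj * vd for a, vd in zip(acc, v)]
        row_max = new_max
    return [a / norm for a in acc]


if __name__ == "__main__":
    scores = [0.3, -1.2, 2.0, 0.7, -0.4]
    values = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0], [0.5, 0.5], [-1.0, 3.0]]
    print(reference_with_sink(scores, values, 1.1))
    print(online_with_seeded_sink(scores, values, 1.1))
```

The seeded state already accounts for the sink, so the block loop needs no sink-specific branch. This is the sense in which the hot loop stays sink-agnostic.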

ragged_kernel

static def ragged_kernel[k_t: MHAOperand, v_t: MHAOperand, mask_t: MHAMask, qkv_dtype: DType, output_dtype: DType, cross_attention: Bool = False, sink: Bool = False](q_ptr: UnsafePointer[Scalar[qkv_dtype], ImmutAnyOrigin], k: k_t, v: v_t, output_ptr: UnsafePointer[Scalar[output_dtype], MutAnyOrigin], mask_functor: mask_t, scale: Float32, input_row_offsets_ptr: UnsafePointer[UInt32, ImmutAnyOrigin], kv_input_row_offsets_ptr: UnsafePointer[UInt32, ImmutAnyOrigin], sink_weights_ptr: UnsafePointer[Scalar[qkv_dtype], ImmutAnyOrigin])

Ragged-batch GPU kernel entry: per-sequence setup + run.

The non-ragged equivalent is run itself (which takes already-sliced per-batch TileTensors). For ragged, this wrapper does the per-block ragged setup so the launcher can pass a single packed Q pointer + input_row_offsets.

cross_attention=False (default): self-attention, where the K/V length equals the Q length plus any cached prefix, so num_keys = start_pos + seq_len. kv_input_row_offsets_ptr is unused (the caller may pass any well-typed stub).

cross_attention=True: encoder-decoder style. K/V lengths come from kv_input_row_offsets_ptr, independent of the Q-side offsets. Mirrors the FA2 contract at mha.mojo:1755-1762.
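The following sketches the per-sequence lengths the wrapper needs. It assumes the offset arrays are exclusive prefix sums with batch + 1 entries, so sequence b occupies packed rows [offsets[b], offsets[b+1]). For self-attention, `start_pos` (the cached-prefix length) is passed in directly, because this page does not say where the kernel obtains it. `ragged_lengths` is a hypothetical name.

```python
def ragged_lengths(batch_idx, input_row_offsets, kv_input_row_offsets=None,
                   start_pos=0, cross_attention=False):
    """Return (q_row_start, seq_len, num_keys) for one sequence of a ragged batch."""
    q_start = input_row_offsets[batch_idx]
    seq_len = input_row_offsets[batch_idx + 1] - q_start
    if cross_attention:
        if kv_input_row_offsets is None:
            raise ValueError("cross-attention needs kv_input_row_offsets")
        # K/V length comes from its own offsets, independent of the Q side.
        num_keys = kv_input_row_offsets[batch_idx + 1] - kv_input_row_offsets[batch_idx]
    else:
        # Self-attention: cached prefix plus the new Q rows.
        num_keys = start_pos + seq_len
    return q_start, seq_len, num_keys


if __name__ == "__main__":
    offsets = [0, 5, 12, 20]
    print(ragged_lengths(1, offsets, start_pos=3))
    print(ragged_lengths(2, offsets, kv_input_row_offsets=[0, 9, 9, 40], cross_attention=True))
```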