
Mojo struct

MlaPrefillV2

struct MlaPrefillV2[config: MlaConfigV2]

Fresh single-schedule port of the reference integrated MLA-prefill inner loop for gfx950. Design points:

  • one wave per EU;
  • resident Q;
  • a single FP32 score tile, with P collapsed to FP8 in place;
  • a shared K/V band;
  • a 64-VGPR FP32 O accumulator with eager rescale;
  • the reference work-split K/V DMA plus a deep, even wave-spec stagger over 160 KB of LDS.

AMD FP8 MLA-prefill kernel (FP8 / KV>=128 / 32x32x64). The reused numeric closure lives in mla_components.mojo (MlaPrefillV2Core); this file consumes only that module. The single reference-exact inner loop (_attend_exact) is described in the module docstring.
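The "FP32 O accumulator with eager rescale" follows the online-softmax pattern used by flash-attention-style kernels. Below is a hedged, scalar Python sketch of that update for one query row over K/V blocks. It is not the kernel's code. We read "eager" as rescaling the accumulator on every block, and the FP8 collapse of P is not modeled. The function names are our own.

```python
import math


def attend_online(q, keys, values, scale, kv_block):
    """Attend one query row over K/V processed in kv_block-sized chunks.

    Keeps a running max m, a running denominator l, and an unnormalized
    output accumulator acc. Whenever a block arrives, acc and l are
    rescaled by exp(m_old - m_new) right away (our reading of "eager").
    """
    depth = len(values[0])
    m = -math.inf
    l = 0.0
    acc = [0.0] * depth
    for start in range(0, len(keys), kv_block):
        k_blk = keys[start:start + kv_block]
        v_blk = values[start:start + kv_block]
        # Score tile for this block (FP32 in the kernel, Python floats here).
        s = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in k_blk]
        m_new = max(m, max(s))
        # Eager rescale of the running state; exp(-inf) is 0.0 on block one.
        alpha = math.exp(m - m_new)
        acc = [a * alpha for a in acc]
        l *= alpha
        # Unnormalized probabilities P (the kernel collapses these to FP8).
        p = [math.exp(si - m_new) for si in s]
        l += sum(p)
        for pj, v in zip(p, v_blk):
            for d in range(depth):
                acc[d] += pj * v[d]
        m = m_new
    return [a / l for a in acc]


def attend_reference(q, keys, values, scale):
    """Plain softmax(scale * q . K^T) @ V, for comparison."""
    s = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(s)
    p = [math.exp(si - m) for si in s]
    total = sum(p)
    depth = len(values[0])
    return [sum(pj * v[d] for pj, v in zip(p, values)) / total
            for d in range(depth)]
```

With kv_block smaller than the number of keys, both functions agree up to floating-point rounding. The blockwise form is what allows a single score tile to be reused for every K/V block.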

Parameters

  • config (MlaConfigV2): Shape configuration. The target shape is the FP8, KV=128 DeepSeek-V3 MLA shape (q_block=32, kv_block=128, depth=128, d_qk=192, d_rope=64, cache_depth=576).
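The target-shape numbers fit together arithmetically. The q argument is described with d_qk = d_nope + d_rope, so d_nope = 192 - 64 = 128, which equals depth. A 576-wide cache row leaves 512 columns ahead of a 64-column rope segment, assuming rope is stored last. The hypothetical dataclass below records those values and derived widths. MlaShapeSketch is our name, not the library's, and its fields follow the shape description rather than MlaConfigV2's actual members.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MlaShapeSketch:
    """Hypothetical Python stand-in holding the target-shape values.

    Field names follow the shape description, not necessarily MlaConfigV2.
    """
    q_block: int = 32
    kv_block: int = 128
    depth: int = 128
    d_qk: int = 192
    d_rope: int = 64
    cache_depth: int = 576

    @property
    def d_nope(self) -> int:
        # Q/K head dim is split as d_qk = d_nope + d_rope.
        return self.d_qk - self.d_rope

    @property
    def cache_prefix_width(self) -> int:
        # Latent-cache columns ahead of the rope segment, ASSUMING the
        # rope segment occupies the last d_rope columns of a cache row.
        return self.cache_depth - self.d_rope


cfg = MlaShapeSketch()
```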

Implemented traits

AnyType, ImplicitlyDeletable

comptime members

BM

comptime BM = (config * config)

D_QK

comptime D_QK = config.d_qk

DEPTH

comptime DEPTH = config.depth

KV_BLOCK

comptime KV_BLOCK = config.kv_block

NUM_HEADS

comptime NUM_HEADS = config.num_heads

NUM_KV_HEADS

comptime NUM_KV_HEADS = config.num_kv_heads

NUM_THREADS

comptime NUM_THREADS = (config * Int(64))

NUM_WARPS

comptime NUM_WARPS = config.num_warps

Q_BLOCK_SIZE

comptime Q_BLOCK_SIZE = config.q_block_size

Methods

run

static def run[k_nope_t: MHAOperand, k_rope_t: MHAOperand, v_t: MHAOperand, mask_t: MHAMask, q_dtype: DType, output_dtype: DType, q_layout: TensorLayout, o_layout: TensorLayout, ragged: Bool = False](q: TileTensor[q_dtype, q_layout, ImmutAnyOrigin], k_nope_op: k_nope_t, k_rope_op: k_rope_t, v_op: v_t, o: TileTensor[output_dtype, o_layout, MutAnyOrigin], mask_functor: mask_t, scale: Float32, num_keys: Int, start_pos: Int, work_indptr_ptr: UnsafePointer[Int32, ImmutAnyOrigin] = UnsafePointer.unsafe_dangling(), work_info_ptr: UnsafePointer[Int32, ImmutAnyOrigin] = UnsafePointer.unsafe_dangling(), num_works: Int = Int(0))

Multi-block, 8-warp MLA forward pass with the reference integrated cadence.

Grid: (NUM_HEADS, ceildiv(seq_len, BM), batch). Each block owns one (batch, head, BM-tile) slice, and the 8 warps split that tile's Q rows. The grid and operand contract is the same as for MlaPrefillV2Core.run.
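The documented grid shape can be computed directly. The Python sketch below does that and also shows one way the 8 warps could divide a BM-tile's Q rows. The contiguous per-warp split is a hypothetical illustration, since the page only says the warps "split" the rows. BM = 256 (8 warps x 32 rows) and 128 heads are example values chosen for illustration, not taken from the struct.

```python
def ceildiv(a: int, b: int) -> int:
    # Integer ceiling division without floats.
    return -(-a // b)


def launch_grid(num_heads: int, seq_len: int, batch: int, bm: int):
    # Documented grid: (NUM_HEADS, ceildiv(seq_len, BM), batch).
    return (num_heads, ceildiv(seq_len, bm), batch)


def warp_rows(tile_idx, warp_id, bm, seq_len, num_warps=8):
    # HYPOTHETICAL: each warp takes a contiguous bm // num_warps slice of
    # the tile's Q rows; rows past seq_len (partial last tile) are dropped.
    rows_per_warp = bm // num_warps
    start = tile_idx * bm + warp_id * rows_per_warp
    stop = min(start + rows_per_warp, seq_len)
    return range(start, max(start, stop))


# Illustrative numbers only: BM = 256 is an assumption.
grid = launch_grid(num_heads=128, seq_len=1000, batch=2, bm=256)
```

With seq_len = 1000, the last of the four tiles is partial, so its final warp covers only 8 rows.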

Args:

  • q (TileTensor[q_dtype, q_layout, ImmutAnyOrigin]): Q tile tensor with head dimension d_qk = d_nope + d_rope.
  • k_nope_op (k_nope_t): K (nope segment) operand, head_dim_idx=0. Also serves as the single-base K loader source: the _MlaKDmaPair reads both the nope columns [0, D_NOPE) and the rope columns [ROPE_CACHE_OFFSET, ROPE_CACHE_OFFSET + D_ROPE) from this one operand.
  • k_rope_op (k_rope_t): K (rope segment) operand. Unused here, because the unified _MlaKDmaPair slices rope from k_nope_op's full latent-cache row; kept in the signature for contract parity with MlaPrefillV2Core.run and the dispatcher.
  • v_op (v_t): V operand (the nope segment of the latent cache), head_dim_idx=0.
  • o (TileTensor[output_dtype, o_layout, MutAnyOrigin]): Output tile tensor with head dimension d_pv = depth.
  • mask_functor (mask_t): Per-tile mask predicate (causal, null, ...).
  • scale (Float32): Softmax scale, typically 1 / sqrt(d_qk).
  • num_keys (Int): Runtime length of the K/V sequence.
  • start_pos (Int): Position of the first Q row in the global sequence.
  • work_indptr_ptr (UnsafePointer[Int32, ImmutAnyOrigin]): Device pointer to the persistent-prefill work-partition prefix sum (num_cu + 1 entries). Passed through for the S2 persistent grid; unused by the current static grid.
  • work_info_ptr (UnsafePointer[Int32, ImmutAnyOrigin]): Device pointer to the persistent-prefill flat WorkInfo array (num_works * 8 int32 values). Passed through for S2; unused by the current static grid.
  • num_works (Int): Total number of work tiles in work_info_ptr. Passed through for S2; unused by the current static grid.
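The k_nope_op description implies that a single latent-cache row supplies both K segments, which is why k_rope_op can go unused. The Python sketch below gathers a d_qk-wide K row from one cache row and computes the typical softmax scale. This page does not give the value of ROPE_CACHE_OFFSET. The code assumes it is 512 (cache_depth - d_rope), and gather_k_row is a hypothetical helper, not part of the kernel.

```python
import math

D_NOPE = 128        # target shape: d_qk - d_rope
D_ROPE = 64
CACHE_DEPTH = 576
# ASSUMPTION: the rope segment is stored last in each latent-cache row, so
# ROPE_CACHE_OFFSET = CACHE_DEPTH - D_ROPE. The real constant is not given here.
ROPE_CACHE_OFFSET = CACHE_DEPTH - D_ROPE


def gather_k_row(row):
    """Build one d_qk-wide K row from a single latent-cache row.

    Reads nope columns [0, D_NOPE) and rope columns
    [ROPE_CACHE_OFFSET, ROPE_CACHE_OFFSET + D_ROPE) from the same row,
    mirroring the single-operand description of k_nope_op.
    """
    if len(row) != CACHE_DEPTH:
        raise ValueError("expected a CACHE_DEPTH-wide latent-cache row")
    nope = row[0:D_NOPE]
    rope = row[ROPE_CACHE_OFFSET:ROPE_CACHE_OFFSET + D_ROPE]
    return nope + rope


# Softmax scale as described for `scale`: 1 / sqrt(d_qk).
scale = 1.0 / math.sqrt(D_NOPE + D_ROPE)
```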

ragged_kernel

static def ragged_kernel[k_nope_t: MHAOperand, k_rope_t: MHAOperand, v_t: MHAOperand, mask_t: MHAMask, qkv_dtype: DType, output_dtype: DType](q_ptr: UnsafePointer[Scalar[qkv_dtype], ImmutAnyOrigin], k_nope_op: k_nope_t, k_rope_op: k_rope_t, v_op: v_t, output_ptr: UnsafePointer[Scalar[output_dtype], MutAnyOrigin], mask_functor: mask_t, scale: Float32, input_row_offsets_ptr: UnsafePointer[UInt32, ImmutAnyOrigin])

Ragged-batch GPU kernel entry. Per-sequence setup mirrors MlaPrefillV2Core.ragged_kernel (self-attention; num_keys = start_pos + seq_len).
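The Python sketch below illustrates the documented relation num_keys = start_pos + seq_len by deriving per-sequence values from input_row_offsets. It assumes input_row_offsets is the usual exclusive prefix sum of sequence lengths (batch + 1 entries). It also includes a standard offset-causal rule as one example of what a causal mask_functor would decide. Both function names are hypothetical.

```python
def sequence_setup(input_row_offsets, batch_idx, start_pos=0):
    """Per-sequence setup for one entry of a ragged batch (hypothetical).

    ASSUMPTION: input_row_offsets is an exclusive prefix sum of sequence
    lengths with batch + 1 entries, so sequence b owns packed rows
    [offsets[b], offsets[b + 1]).
    """
    row_start = input_row_offsets[batch_idx]
    seq_len = input_row_offsets[batch_idx + 1] - row_start
    num_keys = start_pos + seq_len  # as documented for ragged_kernel
    return row_start, seq_len, num_keys


def causal_visible(q_row, k_pos, start_pos=0):
    # Standard offset-causal rule: local query row i sits at global position
    # start_pos + i and may attend to keys at positions <= start_pos + i.
    return k_pos <= start_pos + q_row
```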