Mojo struct
MlaPrefillV2
struct MlaPrefillV2[config: MlaConfigV2]
Fresh single-schedule port of the reference integrated MLA-prefill inner loop for gfx950. 1 wave / EU; resident Q; single FP32 score tile with in-place FP8 P collapse; shared K/V band; 64-VGPR FP32 O accumulator with eager rescale; reference work-split K/V DMA + deep even wave-spec stagger over a 160 KB LDS.
AMD FP8 MLA-prefill kernel (FP8 / KV>=128 / 32x32x64). The reused
numeric closure lives in mla_components.mojo
(MlaPrefillV2Core); this file consumes only that module.
The single reference-exact inner loop (_attend_exact) is
described in the module docstring.
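The "FP32 O accumulator with eager rescale" phrase can be read as the usual online-softmax recurrence, where the output accumulator is rescaled by exp(m_old - m_new) on every K/V block as the running row maximum changes. The following is a minimal plain-Python sketch of that recurrence for a single query row, assuming this reading; it is not the kernel's code, it ignores FP8 quantization, masking, and the LDS/DMA schedule, and the names attend_online and attend_reference are hypothetical.

```python
import math


def attend_online(q, keys, values, scale, kv_block):
    """Single-row attention over K/V blocks with a running max and eager rescale."""
    m = -math.inf                 # running row maximum of the scaled scores
    l = 0.0                       # running softmax denominator
    o = [0.0] * len(values[0])    # FP32-style output accumulator
    for start in range(0, len(keys), kv_block):
        k_blk = keys[start:start + kv_block]
        v_blk = values[start:start + kv_block]
        # Scaled scores for this block.
        s = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in k_blk]
        m_new = max(m, max(s))
        # Rescale factor for everything accumulated so far (0.0 on the first block).
        alpha = math.exp(m - m_new)
        p = [math.exp(x - m_new) for x in s]
        l = alpha * l + sum(p)
        o = [alpha * oc for oc in o]          # eager rescale of the accumulator
        for pj, v in zip(p, v_blk):
            for c, vc in enumerate(v):
                o[c] += pj * vc
        m = m_new
    return [oc / l for oc in o]


def attend_reference(q, keys, values, scale):
    """Plain softmax attention over the whole sequence, for comparison."""
    s = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    denom = sum(p)
    return [sum(pj * v[c] for pj, v in zip(p, values)) / denom
            for c in range(len(values[0]))]
```

Both functions should agree up to floating-point rounding for any block size, which is what makes the blockwise form usable as a streaming inner loop.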
Parameters
- config (MlaConfigV2): Shape configuration. The target shape is the FP8 KV=128 DeepSeek-V3 MLA shape (q_block=32, kv_block=128, depth=128, d_qk=192, d_rope=64, cache_depth=576).
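As a quick sanity check on the target shape, the split d_qk = d_nope + d_rope (stated for the q argument of run) implies d_nope = 128, and the typical softmax scale 1 / sqrt(d_qk) comes out near 0.0722. The snippet below is a plain-Python illustration, not Mojo; the variable names are chosen here for readability and are not fields of MlaConfigV2.

```python
import math

# Values quoted for the target DeepSeek-V3 MLA shape.
d_qk = 192
d_rope = 64

# d_qk = d_nope + d_rope, so the nope segment is what remains.
d_nope = d_qk - d_rope

# Typical softmax scale for this head dimension.
scale = 1.0 / math.sqrt(d_qk)

print(d_nope)            # 128
print(round(scale, 4))   # 0.0722
```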
Implemented traits
comptime members
BM
comptime BM = (config * config)
D_QK
comptime D_QK = config.d_qk
DEPTH
comptime DEPTH = config.depth
KV_BLOCK
comptime KV_BLOCK = config.kv_block
NUM_HEADS
comptime NUM_HEADS = config.num_heads
NUM_KV_HEADS
comptime NUM_KV_HEADS = config.num_kv_heads
NUM_THREADS
comptime NUM_THREADS = (config * Int(64))
NUM_WARPS
comptime NUM_WARPS = config.num_warps
Q_BLOCK_SIZE
comptime Q_BLOCK_SIZE = config.q_block_size
Methods
run
static def run[k_nope_t: MHAOperand, k_rope_t: MHAOperand, v_t: MHAOperand, mask_t: MHAMask, q_dtype: DType, output_dtype: DType, q_layout: TensorLayout, o_layout: TensorLayout, ragged: Bool = False](q: TileTensor[q_dtype, q_layout, ImmutAnyOrigin], k_nope_op: k_nope_t, k_rope_op: k_rope_t, v_op: v_t, o: TileTensor[output_dtype, o_layout, MutAnyOrigin], mask_functor: mask_t, scale: Float32, num_keys: Int, start_pos: Int, work_indptr_ptr: UnsafePointer[Int32, ImmutAnyOrigin] = UnsafePointer.unsafe_dangling(), work_info_ptr: UnsafePointer[Int32, ImmutAnyOrigin] = UnsafePointer.unsafe_dangling(), num_works: Int = Int(0))
Multi-block 8-warp MLA forward, reference integrated cadence.
Grid: (NUM_HEADS, ceildiv(seq_len, BM), batch). Each block owns
one (batch, head, BM-tile) slice; the 8 warps split the
BM-tile's Q rows. Same grid/operand contract as
MlaPrefillV2Core.run.
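The grid shape above can be computed on the host as in the following plain-Python sketch. The helper names and the example numbers (128 heads, seq_len 1000, BM 256, batch 2) are illustrative assumptions only; the actual BM is derived from config as shown in the comptime members.

```python
def ceildiv(a: int, b: int) -> int:
    """Ceiling division for non-negative a and positive b."""
    return -(-a // b)


def launch_grid(num_heads: int, seq_len: int, bm: int, batch: int) -> tuple:
    """Grid as (NUM_HEADS, ceildiv(seq_len, BM), batch): one block per (batch, head, BM-tile)."""
    return (num_heads, ceildiv(seq_len, bm), batch)


# Hypothetical values, chosen only to show a partial last tile.
grid = launch_grid(num_heads=128, seq_len=1000, bm=256, batch=2)
print(grid)  # (128, 4, 2)
```

With these numbers the last BM-tile covers only 232 of its 256 rows, so any launch of this shape has to handle a partially filled final tile.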
Args:
- q (TileTensor[q_dtype, q_layout, ImmutAnyOrigin]): Q tile tensor at d_qk = d_nope + d_rope.
- k_nope_op (k_nope_t): K (nope segment) operand, head_dim_idx=0. Also serves as the single-base K loader source (the _MlaKDmaPair reads both nope cols [0, D_NOPE) and rope cols [ROPE_CACHE_OFFSET, +D_ROPE) from this one operand).
- k_rope_op (k_rope_t): K (rope segment) operand. Unused here: the unified _MlaKDmaPair slices rope from k_nope_op's full latent-cache row. Kept in the signature for contract parity with MlaPrefillV2Core.run and the dispatcher.
- v_op (v_t): V operand (= nope segment of the latent cache), head_dim_idx=0.
- o (TileTensor[output_dtype, o_layout, MutAnyOrigin]): Output tile tensor at d_pv = depth.
- mask_functor (mask_t): Per-tile mask predicate (causal / null / ...).
- scale (Float32): Softmax scale (typically 1 / sqrt(d_qk)).
- num_keys (Int): Runtime length of the K/V sequence.
- start_pos (Int): Position of the first Q row in the global sequence.
- work_indptr_ptr (UnsafePointer[Int32, ImmutAnyOrigin]): Persistent-prefill work partition prefix-sum [num_cu+1] (device). Threaded for the S2 persistent grid; unused by the current static grid.
- work_info_ptr (UnsafePointer[Int32, ImmutAnyOrigin]): Persistent-prefill flat WorkInfo array [num_works*8] int32 (device). Threaded for S2; unused by the current static grid.
- num_works (Int): Total number of work-tiles in work_info_ptr. Threaded for S2; unused by the current static grid.
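The work_indptr_ptr argument is described only as a prefix-sum of length num_cu+1. One plausible reading, assumed here and not confirmed by this page, is a CSR-style layout in which compute unit i owns work-tiles indptr[i] up to but not including indptr[i+1]. The plain-Python sketch below illustrates that reading; the helper names are hypothetical, and the contents of each 8-int32 WorkInfo record are not modeled.

```python
from itertools import accumulate


def build_work_indptr(works_per_cu):
    """Exclusive prefix sum of per-CU work counts; length is num_cu + 1."""
    return [0, *accumulate(works_per_cu)]


def work_range(indptr, cu):
    """Work-tile indices assigned to one compute unit under the assumed layout."""
    return range(indptr[cu], indptr[cu + 1])


# Hypothetical partition: 4 compute units with 3, 2, 0 and 4 work-tiles.
indptr = build_work_indptr([3, 2, 0, 4])
num_works = indptr[-1]

print(indptr)                      # [0, 3, 5, 5, 9]
print(list(work_range(indptr, 3))) # [5, 6, 7, 8]
```

Under this layout a compute unit with no work simply sees an empty range, and num_works equals the last prefix-sum entry.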
ragged_kernel
static def ragged_kernel[k_nope_t: MHAOperand, k_rope_t: MHAOperand, v_t: MHAOperand, mask_t: MHAMask, qkv_dtype: DType, output_dtype: DType](q_ptr: UnsafePointer[Scalar[qkv_dtype], ImmutAnyOrigin], k_nope_op: k_nope_t, k_rope_op: k_rope_t, v_op: v_t, output_ptr: UnsafePointer[Scalar[output_dtype], MutAnyOrigin], mask_functor: mask_t, scale: Float32, input_row_offsets_ptr: UnsafePointer[UInt32, ImmutAnyOrigin])
Ragged-batch GPU kernel entry. Per-sequence setup mirrors MlaPrefillV2Core.ragged_kernel (self-attention; num_keys = start_pos + seq_len).
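The per-sequence setup can be pictured with the following plain-Python sketch. It assumes that input_row_offsets holds batch+1 cumulative row offsets into the packed Q tensor, so that seq_len is the difference of adjacent entries; that layout is an assumption here, as is the hypothetical start_positions input standing in for however the kernel obtains each sequence's start_pos. Only the relation num_keys = start_pos + seq_len comes from the description above.

```python
def per_sequence_setup(input_row_offsets, start_positions):
    """Derive per-sequence sizes for a ragged batch under the assumed offset layout."""
    setups = []
    for b, start_pos in enumerate(start_positions):
        row_start = input_row_offsets[b]
        seq_len = input_row_offsets[b + 1] - row_start
        setups.append({
            "q_row_start": row_start,          # first packed Q row of this sequence
            "seq_len": seq_len,                # new tokens in this sequence
            "start_pos": start_pos,            # position of the first Q row
            "num_keys": start_pos + seq_len,   # self-attention: keys cover prior + new tokens
        })
    return setups


# Hypothetical batch of three sequences with 5, 0 and 7 new tokens.
setups = per_sequence_setup([0, 5, 5, 12], [10, 0, 100])
print([s["num_keys"] for s in setups])  # [15, 0, 107]
```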