IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.
Skip to main content
For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).

Mojo function

flare_mla_prefill

def flare_mla_prefill[rank: Int, cache_t: KVCacheT, mask_t: MHAMask, dtype: DType, output_type: DType, //](output: TileTensor[output_type, Storage=output.Storage, linear_idx_type=output.linear_idx_type, element_size=output.element_size], q: TileTensor[dtype, Storage=q.Storage, linear_idx_type=q.linear_idx_type, element_size=q.element_size], k: LayoutTensor[element_layout=k.element_layout, layout_int_type=k.layout_int_type, linear_idx_type=k.linear_idx_type, masked=k.masked, alignment=k.alignment], v: LayoutTensor[element_layout=v.element_layout, layout_int_type=v.layout_int_type, linear_idx_type=v.linear_idx_type, masked=v.masked, alignment=v.alignment], k_rope: cache_t, mask_functor: mask_t, valid_length: TileTensor[DType.uint32, Storage=valid_length.Storage, linear_idx_type=valid_length.linear_idx_type, element_size=valid_length.element_size], cache_row_offsets: TileTensor[DType.uint32, Storage=cache_row_offsets.Storage, linear_idx_type=cache_row_offsets.linear_idx_type, element_size=cache_row_offsets.element_size], scale: Float32, ctx: DeviceContext, q_max_seq_len: OptionalReg[Int] = None, cache_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(Int(-1)), MutAnyOrigin]] = None)

MLA prefill kernel that is only called from the optimized compute graph. It supports only ragged Q/K/V inputs.

The Q input has a shape of [seq_len, num_heads, q_depth]. The K and V inputs each have a shape of [cache_len, num_heads, depth]. The K_rope input is retrieved from the KV cache and has a shape of [cache_len, 1, q_depth - depth].

Specifically, for DeepSeek V2/3, depth = 128 and q_depth = 192.

When computing attention scores (Q @ K), each head of K is smaller than the corresponding head of Q. The missing q_depth - depth elements of K (64 for DeepSeek V2/3) are retrieved from the K cache and broadcast to all heads. This kernel also handles the output having a reduced dimension compared to the input Q.

This kernel handles batches whose sequences have different valid lengths (i.e., the lengths before padding). These lengths are passed in the valid_length argument.

def flare_mla_prefill[rank: Int, mask_t: MHAMask, dtype: DType, //](output: TileTensor[Storage=output.Storage, linear_idx_type=output.linear_idx_type, element_size=output.element_size], q: TileTensor[dtype, Storage=q.Storage, linear_idx_type=q.linear_idx_type, element_size=q.element_size], k: TileTensor[Storage=k.Storage, linear_idx_type=k.linear_idx_type, element_size=k.element_size], v: TileTensor[Storage=v.Storage, linear_idx_type=v.linear_idx_type, element_size=v.element_size], k_rope: TileTensor[Storage=k_rope.Storage, linear_idx_type=k_rope.linear_idx_type, element_size=k_rope.element_size], mask_functor: mask_t, valid_length: TileTensor[DType.uint32, Storage=valid_length.Storage, linear_idx_type=valid_length.linear_idx_type, element_size=valid_length.element_size], cache_row_offsets: TileTensor[DType.uint32, Storage=cache_row_offsets.Storage, linear_idx_type=cache_row_offsets.linear_idx_type, element_size=cache_row_offsets.element_size], scale: Float32, ctx: DeviceContext, q_max_seq_len: OptionalReg[Int] = None, cache_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(Int(-1)), MutAnyOrigin]] = None)

def flare_mla_prefill[rank: Int, mask_t: MHAMask, dtype: DType, //](output: TileTensor[Storage=output.Storage, linear_idx_type=output.linear_idx_type, element_size=output.element_size], q: TileTensor[dtype, Storage=q.Storage, linear_idx_type=q.linear_idx_type, element_size=q.element_size], k: TileTensor[Storage=k.Storage, linear_idx_type=k.linear_idx_type, element_size=k.element_size], v: TileTensor[Storage=v.Storage, linear_idx_type=v.linear_idx_type, element_size=v.element_size], k_rope: TileTensor[Storage=k_rope.Storage, linear_idx_type=k_rope.linear_idx_type, element_size=k_rope.element_size], k_rope_scales: TileTensor[Storage=k_rope_scales.Storage, linear_idx_type=k_rope_scales.linear_idx_type, element_size=k_rope_scales.element_size], mask_functor: mask_t, valid_length: TileTensor[DType.uint32, Storage=valid_length.Storage, linear_idx_type=valid_length.linear_idx_type, element_size=valid_length.element_size], cache_row_offsets: TileTensor[DType.uint32, Storage=cache_row_offsets.Storage, linear_idx_type=cache_row_offsets.linear_idx_type, element_size=cache_row_offsets.element_size], scale: Float32, ctx: DeviceContext, q_max_seq_len: OptionalReg[Int] = None, cache_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(Int(-1)), MutAnyOrigin]] = None)

def flare_mla_prefill[rank: Int, mask_t: MHAMask, dtype: DType, scale_dtype: DType, //](output: TileTensor[Storage=output.Storage, linear_idx_type=output.linear_idx_type, element_size=output.element_size], q_nope: TileTensor[dtype, Storage=q_nope.Storage, linear_idx_type=q_nope.linear_idx_type, element_size=q_nope.element_size], q_rope: TileTensor[Storage=q_rope.Storage, linear_idx_type=q_rope.linear_idx_type, element_size=q_rope.element_size], q_scale: TileTensor[scale_dtype, Storage=q_scale.Storage, linear_idx_type=q_scale.linear_idx_type, element_size=q_scale.element_size], k: TileTensor[Storage=k.Storage, linear_idx_type=k.linear_idx_type, element_size=k.element_size], k_scales: TileTensor[scale_dtype, Storage=k_scales.Storage, linear_idx_type=k_scales.linear_idx_type, element_size=k_scales.element_size], v: TileTensor[Storage=v.Storage, linear_idx_type=v.linear_idx_type, element_size=v.element_size], k_rope: TileTensor[Storage=k_rope.Storage, linear_idx_type=k_rope.linear_idx_type, element_size=k_rope.element_size], mask_functor: mask_t, valid_length: TileTensor[DType.uint32, Storage=valid_length.Storage, linear_idx_type=valid_length.linear_idx_type, element_size=valid_length.element_size], cache_row_offsets: TileTensor[DType.uint32, Storage=cache_row_offsets.Storage, linear_idx_type=cache_row_offsets.linear_idx_type, element_size=cache_row_offsets.element_size], scale: Float32, ctx: DeviceContext, q_max_seq_len: OptionalReg[Int] = None, cache_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(Int(-1)), MutAnyOrigin]] = None)

def flare_mla_prefill[rank: Int, cache_t: KVCacheT, mask_t: MHAMask, dtype: DType, scale_dtype: DType, //](output: TileTensor[Storage=output.Storage, linear_idx_type=output.linear_idx_type, element_size=output.element_size], q_nope: TileTensor[dtype, Storage=q_nope.Storage, linear_idx_type=q_nope.linear_idx_type, element_size=q_nope.element_size], q_rope: TileTensor[Storage=q_rope.Storage, linear_idx_type=q_rope.linear_idx_type, element_size=q_rope.element_size], q_scale: TileTensor[scale_dtype, Storage=q_scale.Storage, linear_idx_type=q_scale.linear_idx_type, element_size=q_scale.element_size], k: TileTensor[Storage=k.Storage, linear_idx_type=k.linear_idx_type, element_size=k.element_size], k_scales: TileTensor[scale_dtype, Storage=k_scales.Storage, linear_idx_type=k_scales.linear_idx_type, element_size=k_scales.element_size], v: TileTensor[Storage=v.Storage, linear_idx_type=v.linear_idx_type, element_size=v.element_size], k_rope: cache_t, mask_functor: mask_t, valid_length: TileTensor[DType.uint32, Storage=valid_length.Storage, linear_idx_type=valid_length.linear_idx_type, element_size=valid_length.element_size], cache_row_offsets: TileTensor[DType.uint32, Storage=cache_row_offsets.Storage, linear_idx_type=cache_row_offsets.linear_idx_type, element_size=cache_row_offsets.element_size], scale: Float32, ctx: DeviceContext, q_max_seq_len: OptionalReg[Int] = None, cache_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(Int(-1)), MutAnyOrigin]] = None)