Mojo function

flare_mla_prefill

flare_mla_prefill[rank: Int, cache_t: KVCacheT, mask_t: MHAMask, dtype: DType, output_type: DType, //](output: TileTensor[output_type, linear_idx_type=output.linear_idx_type, element_size=output.element_size], q: TileTensor[dtype, linear_idx_type=q.linear_idx_type, element_size=q.element_size], k: LayoutTensor[element_layout=k.element_layout, layout_int_type=k.layout_int_type, linear_idx_type=k.linear_idx_type, masked=k.masked, alignment=k.alignment], v: LayoutTensor[element_layout=v.element_layout, layout_int_type=v.layout_int_type, linear_idx_type=v.linear_idx_type, masked=v.masked, alignment=v.alignment], k_rope: cache_t, mask_functor: mask_t, valid_length: TileTensor[DType.uint32, linear_idx_type=valid_length.linear_idx_type, element_size=valid_length.element_size], cache_row_offsets: TileTensor[DType.uint32, linear_idx_type=cache_row_offsets.linear_idx_type, element_size=cache_row_offsets.element_size], scale: Float32, ctx: DeviceContext, q_max_seq_len: OptionalReg[Int] = None, cache_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), MutAnyOrigin]] = None)

MLA prefill kernel, called only from the optimized compute graph. It supports only ragged Q/K/V inputs.

The Q input has a shape of [seq_len, num_heads, q_depth]. The K and V inputs each have a shape of [cache_len, num_heads, depth]. The K_rope input is retrieved from the KV cache and has a shape of [cache_len, 1, q_depth - depth].

Specifically, for DeepSeek V2/3, depth = 128 and q_depth = 192.

When computing attention scores (Q @ K), each K head is smaller than the corresponding Q head. The missing q_depth - depth elements of each K head (64 for DeepSeek V2/3) are retrieved from the K cache and broadcast to all heads. This kernel also handles the fact that the output has a reduced head dimension (depth) compared to the Q input (q_depth).
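The head assembly described above can be sketched at the shape level in NumPy. This is an illustrative sketch of the data layout, not the kernel itself; the variable names and the 1/sqrt(q_depth) scale are assumptions for the example, with the DeepSeek V2/3 dimensions from above.

```python
import numpy as np

# Illustrative DeepSeek V2/3 dimensions: depth = 128, q_depth = 192.
seq_len, cache_len, num_heads = 4, 16, 8
depth, q_depth = 128, 192

q = np.random.randn(seq_len, num_heads, q_depth).astype(np.float32)
k = np.random.randn(cache_len, num_heads, depth).astype(np.float32)
v = np.random.randn(cache_len, num_heads, depth).astype(np.float32)
# K_rope comes from the KV cache with a single head of q_depth - depth elements.
k_rope = np.random.randn(cache_len, 1, q_depth - depth).astype(np.float32)

# Broadcast the rope part to every head and append it to each K head,
# so each K head matches the q_depth of the Q heads.
k_full = np.concatenate(
    [k, np.broadcast_to(k_rope, (cache_len, num_heads, q_depth - depth))],
    axis=-1,
)
assert k_full.shape == (cache_len, num_heads, q_depth)

# Attention scores per head: Q @ K^T, scaled (the scale here is an
# assumed 1/sqrt(q_depth), standing in for the kernel's `scale` argument).
scale = 1.0 / np.sqrt(q_depth)
scores = np.einsum("qhd,khd->hqk", q, k_full) * scale

# Softmax over the cache axis, then weight V. V heads have only `depth`
# elements, so the output head dimension is reduced relative to Q.
p = np.exp(scores - scores.max(axis=-1, keepdims=True))
p /= p.sum(axis=-1, keepdims=True)
out = np.einsum("hqk,khd->qhd", p, v)
assert out.shape == (seq_len, num_heads, depth)
```

The final assertion shows the dimension reduction the kernel accounts for: Q enters with q_depth = 192 per head, but the output carries only depth = 128 per head, because V heads never include the rope elements.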

This kernel handles batches with different valid lengths (i.e., before the padding). These lengths are passed in the valid_length argument.
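A common convention for such ragged inputs is to pack all sequences back-to-back along the row axis and address each sequence through a prefix sum of the valid lengths. The sketch below illustrates that layout; the prefix-sum convention is an assumption for illustration (in the spirit of the cache_row_offsets argument), not a statement of this kernel's exact indexing.

```python
import numpy as np

# Hypothetical ragged batch of 3 sequences with different valid lengths.
valid_length = np.array([5, 2, 7], dtype=np.uint32)

# All tokens are packed contiguously, so a prefix sum over valid_length
# gives each sequence's starting row in the flattened tensor.
row_offsets = np.concatenate([[0], np.cumsum(valid_length)]).astype(np.uint32)

num_heads, q_depth = 8, 192
q_packed = np.zeros((int(valid_length.sum()), num_heads, q_depth), dtype=np.float32)

# Slice out sequence i's rows from the packed tensor.
i = 1
q_i = q_packed[row_offsets[i]:row_offsets[i + 1]]
assert q_i.shape[0] == valid_length[i]
```

With this layout no padding rows exist at all, which is why the kernel supports only ragged inputs: valid lengths and row offsets fully describe where each sequence's tokens live.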

flare_mla_prefill[rank: Int, mask_t: MHAMask, dtype: DType, //](output: TileTensor[linear_idx_type=output.linear_idx_type, element_size=output.element_size], q: TileTensor[dtype, linear_idx_type=q.linear_idx_type, element_size=q.element_size], k: TileTensor[linear_idx_type=k.linear_idx_type, element_size=k.element_size], v: TileTensor[linear_idx_type=v.linear_idx_type, element_size=v.element_size], k_rope: TileTensor[linear_idx_type=k_rope.linear_idx_type, element_size=k_rope.element_size], mask_functor: mask_t, valid_length: TileTensor[DType.uint32, linear_idx_type=valid_length.linear_idx_type, element_size=valid_length.element_size], cache_row_offsets: TileTensor[DType.uint32, linear_idx_type=cache_row_offsets.linear_idx_type, element_size=cache_row_offsets.element_size], scale: Float32, ctx: DeviceContext, q_max_seq_len: OptionalReg[Int] = None, cache_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), MutAnyOrigin]] = None)

flare_mla_prefill[rank: Int, mask_t: MHAMask, dtype: DType, //](output: TileTensor[linear_idx_type=output.linear_idx_type, element_size=output.element_size], q: TileTensor[dtype, linear_idx_type=q.linear_idx_type, element_size=q.element_size], k: TileTensor[linear_idx_type=k.linear_idx_type, element_size=k.element_size], v: TileTensor[linear_idx_type=v.linear_idx_type, element_size=v.element_size], k_rope: TileTensor[linear_idx_type=k_rope.linear_idx_type, element_size=k_rope.element_size], k_rope_scales: TileTensor[linear_idx_type=k_rope_scales.linear_idx_type, element_size=k_rope_scales.element_size], mask_functor: mask_t, valid_length: TileTensor[DType.uint32, linear_idx_type=valid_length.linear_idx_type, element_size=valid_length.element_size], cache_row_offsets: TileTensor[DType.uint32, linear_idx_type=cache_row_offsets.linear_idx_type, element_size=cache_row_offsets.element_size], scale: Float32, ctx: DeviceContext, q_max_seq_len: OptionalReg[Int] = None, cache_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), MutAnyOrigin]] = None)

flare_mla_prefill[rank: Int, mask_t: MHAMask, dtype: DType, scale_dtype: DType, //](output: TileTensor[linear_idx_type=output.linear_idx_type, element_size=output.element_size], q_nope: TileTensor[dtype, linear_idx_type=q_nope.linear_idx_type, element_size=q_nope.element_size], q_rope: TileTensor[linear_idx_type=q_rope.linear_idx_type, element_size=q_rope.element_size], q_scale: TileTensor[scale_dtype, linear_idx_type=q_scale.linear_idx_type, element_size=q_scale.element_size], k: TileTensor[linear_idx_type=k.linear_idx_type, element_size=k.element_size], k_scales: TileTensor[scale_dtype, linear_idx_type=k_scales.linear_idx_type, element_size=k_scales.element_size], v: TileTensor[linear_idx_type=v.linear_idx_type, element_size=v.element_size], k_rope: TileTensor[linear_idx_type=k_rope.linear_idx_type, element_size=k_rope.element_size], mask_functor: mask_t, valid_length: TileTensor[DType.uint32, linear_idx_type=valid_length.linear_idx_type, element_size=valid_length.element_size], cache_row_offsets: TileTensor[DType.uint32, linear_idx_type=cache_row_offsets.linear_idx_type, element_size=cache_row_offsets.element_size], scale: Float32, ctx: DeviceContext, q_max_seq_len: OptionalReg[Int] = None, cache_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), MutAnyOrigin]] = None)

flare_mla_prefill[rank: Int, cache_t: KVCacheT, mask_t: MHAMask, dtype: DType, scale_dtype: DType, //](output: TileTensor[linear_idx_type=output.linear_idx_type, element_size=output.element_size], q_nope: TileTensor[dtype, linear_idx_type=q_nope.linear_idx_type, element_size=q_nope.element_size], q_rope: TileTensor[linear_idx_type=q_rope.linear_idx_type, element_size=q_rope.element_size], q_scale: TileTensor[scale_dtype, linear_idx_type=q_scale.linear_idx_type, element_size=q_scale.element_size], k: TileTensor[linear_idx_type=k.linear_idx_type, element_size=k.element_size], k_scales: TileTensor[scale_dtype, linear_idx_type=k_scales.linear_idx_type, element_size=k_scales.element_size], v: TileTensor[linear_idx_type=v.linear_idx_type, element_size=v.element_size], k_rope: cache_t, mask_functor: mask_t, valid_length: TileTensor[DType.uint32, linear_idx_type=valid_length.linear_idx_type, element_size=valid_length.element_size], cache_row_offsets: TileTensor[DType.uint32, linear_idx_type=cache_row_offsets.linear_idx_type, element_size=cache_row_offsets.element_size], scale: Float32, ctx: DeviceContext, q_max_seq_len: OptionalReg[Int] = None, cache_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), MutAnyOrigin]] = None)