Mojo function
flare_mla_prefill
flare_mla_prefill[rank: Int, cache_t: KVCacheT, mask_t: MHAMask, score_mod_t: ScoreModTrait, type: DType, output_type: DType, softmax_type: DType, q_shape: DimList, //, use_score_mod: Bool = False, write_softmax_info: Bool = False, use_cascade_attention: Bool = False](output: NDBuffer[output_type, rank, origin, shape, strides], q: NDBuffer[type, rank, origin, q_shape, strides], k: NDBuffer[type, 3, origin, shape, strides], v: NDBuffer[type, 3, origin, shape, strides], k_rope: cache_t, mask_functor: mask_t, score_mod_functor: score_mod_t, valid_length: NDBuffer[uint32, 1, origin, shape, strides], cache_row_offsets: NDBuffer[uint32, 1, origin, shape, strides], scale: SIMD[float32, 1], ctx: DeviceContext, q_max_seq_len: OptionalReg[Int] = OptionalReg[Int]({:i1 0, 1}), softmax_info: OptionalReg[NDBuffer[softmax_type, 3, MutableAnyOrigin]] = OptionalReg[NDBuffer[softmax_type, 3, MutableAnyOrigin]]({:i1 0, 1}), cache_offsets: OptionalReg[NDBuffer[uint32, 1, MutableAnyOrigin]] = OptionalReg[NDBuffer[uint32, 1, MutableAnyOrigin]]({:i1 0, 1}), prev_output: OptionalReg[NDBuffer[output_type, rank, MutableAnyOrigin]] = OptionalReg[NDBuffer[output_type, rank, MutableAnyOrigin]]({:i1 0, 1}), prev_softmax_info: OptionalReg[NDBuffer[softmax_type, 3, MutableAnyOrigin]] = OptionalReg[NDBuffer[softmax_type, 3, MutableAnyOrigin]]({:i1 0, 1}))
MLA prefill kernel intended to be called only from the optimized compute graph. It supports only ragged Q/K/V inputs.
The Q input has shape [seq_len, num_heads, q_depth]. The K and V inputs have shape [cache_len, num_heads, depth]. The K_rope input is retrieved from the KV cache and has shape [cache_len, 1, q_depth - depth].
Specifically, for DeepSeek V2/3, depth = 128 and q_depth = 192.
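The hypothetical Python helper below is not part of the Mojo API. As a quick check of the shapes above, it derives every buffer shape from `q_depth` and `depth`. With the DeepSeek values it produces a 64-element `k_rope` head. The output shape `[seq_len, num_heads, depth]` is an assumption based on this kernel's reduced output dimension, and the head count used here is arbitrary.

```python
def mla_prefill_shapes(seq_len, cache_len, num_heads, q_depth, depth):
    """Hypothetical helper: buffer shapes for the MLA prefill inputs."""
    rope_dim = q_depth - depth  # extra Q elements matched by k_rope
    if rope_dim <= 0:
        raise ValueError("q_depth must exceed depth")
    return {
        "q": (seq_len, num_heads, q_depth),
        "k": (cache_len, num_heads, depth),
        "v": (cache_len, num_heads, depth),
        "k_rope": (cache_len, 1, rope_dim),  # a single head, shared by all heads
        # Assumption: the output uses V's head depth, which is smaller than q_depth.
        "output": (seq_len, num_heads, depth),
    }

# DeepSeek V2/3 depths; seq_len, cache_len and num_heads are arbitrary here.
shapes = mla_prefill_shapes(seq_len=7, cache_len=7, num_heads=16, q_depth=192, depth=128)
print(shapes["k_rope"])  # (7, 1, 64)
```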
When computing attention scores (Q @ K), each K head is smaller than a Q head. The missing q_depth - depth elements of each K head (64 for DeepSeek V2/3) are retrieved from the K cache and broadcast to all heads. The kernel also handles an output whose head dimension is smaller than that of the Q input.
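The score computation can be illustrated with a plain-Python reference. This is a minimal sketch, not the kernel's actual implementation. It makes these assumptions:

- A single unmasked sequence, with no `mask_functor` or `score_mod_functor`.
- Nested lists in the Q/K/V/k_rope layouts described above.
- `mla_attention` is a hypothetical name.

Each key head is rebuilt by appending the shared `k_rope` elements. The output head takes V's `depth` rather than `q_depth`.

```python
import math

def mla_attention(q, k, v, k_rope, scale):
    """Hypothetical reference MLA attention for one sequence, no mask or score mod.

    q:      [seq_len][num_heads][q_depth]
    k, v:   [cache_len][num_heads][depth]
    k_rope: [cache_len][1][q_depth - depth], shared by all heads
    Returns [seq_len][num_heads][depth].
    """
    num_heads = len(q[0])
    out = []
    for qi in q:
        out_row = []
        for h in range(num_heads):
            # Full-width key for head h: its own `depth` elements followed by
            # the rope elements broadcast from the single k_rope head.
            logits = [
                scale * sum(a * b for a, b in zip(qi[h], kj[h] + rj[0]))
                for kj, rj in zip(k, k_rope)
            ]
            m = max(logits)  # subtract the max for a numerically stable softmax
            weights = [math.exp(x - m) for x in logits]
            total = sum(weights)
            depth = len(v[0][h])
            # The output head has V's depth, which is smaller than q_depth.
            out_row.append(
                [sum(w * vj[h][d] for w, vj in zip(weights, v)) / total for d in range(depth)]
            )
        out.append(out_row)
    return out

# Tiny example: num_heads = 2, depth = 2, q_depth = 3 (one rope element).
q = [[[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]
k = [[[1.0, 0.0], [0.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]]]
v = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]
k_rope = [[[1.0]], [[0.0]]]
out = mla_attention(q, k, v, k_rope, scale=1.0)
```

Each output head has 2 elements (V's depth), even though each query head has 3.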
This kernel handles batches whose sequences have different valid lengths (i.e., lengths before padding). These lengths are passed in the valid_length argument.
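In a ragged layout the sequences are packed back to back with no padding, so per-sequence lengths become row offsets through a prefix sum. The sketch below makes these assumptions:

- `valid_length` holds one length per sequence.
- Queries and keys have the same per-sequence length, with no previously cached tokens.
- A causal mask stands in as one example of `mask_t`.
- `row_offsets` and `ragged_causal_pairs` are hypothetical helpers.

```python
from itertools import accumulate

def row_offsets(valid_length):
    """Hypothetical helper: sequence b occupies rows [offsets[b], offsets[b + 1])."""
    return [0] + list(accumulate(valid_length))

def ragged_causal_pairs(valid_length):
    """Hypothetical helper: (batch, query_row, key_row) triples that a causal
    mask allows, using global row indices into the packed (ragged) buffer."""
    offsets = row_offsets(valid_length)
    pairs = []
    for b, n in enumerate(valid_length):
        start = offsets[b]
        for i in range(n):
            # Causal (assumed mask): a query attends only to keys at or before
            # its own position, and never to rows of another sequence.
            for j in range(i + 1):
                pairs.append((b, start + i, start + j))
    return pairs

print(row_offsets([2, 3]))  # [0, 2, 5]
```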
flare_mla_prefill[rank: Int, mask_t: MHAMask, score_mod_t: ScoreModTrait, type: DType, softmax_type: DType, q_shape: DimList, //, use_score_mod: Bool = False, write_softmax_info: Bool = False, use_cascade_attention: Bool = False](output: NDBuffer[type, rank, origin, shape, strides], q: NDBuffer[type, rank, origin, q_shape, strides], k: NDBuffer[type, 3, origin, shape, strides], v: NDBuffer[type, 3, origin, shape, strides], k_rope: NDBuffer[type, 4, origin, shape, strides], mask_functor: mask_t, score_mod_functor: score_mod_t, valid_length: NDBuffer[uint32, 1, origin, shape, strides], cache_row_offsets: NDBuffer[uint32, 1, origin, shape, strides], scale: SIMD[float32, 1], ctx: DeviceContext, q_max_seq_len: OptionalReg[Int] = OptionalReg[Int]({:i1 0, 1}), softmax_info: OptionalReg[NDBuffer[softmax_type, 3, MutableAnyOrigin]] = OptionalReg[NDBuffer[softmax_type, 3, MutableAnyOrigin]]({:i1 0, 1}), cache_offsets: OptionalReg[NDBuffer[uint32, 1, MutableAnyOrigin]] = OptionalReg[NDBuffer[uint32, 1, MutableAnyOrigin]]({:i1 0, 1}))