Mojo function
flare_mla_prefill
flare_mla_prefill[rank: Int, cache_t: KVCacheT, mask_t: MHAMask, dtype: DType, output_type: DType, //](output: TileTensor[output_type, output.LayoutType, output.origin, linear_idx_type=output.linear_idx_type, element_size=output.element_size], q: TileTensor[dtype, q.LayoutType, q.origin, linear_idx_type=q.linear_idx_type, element_size=q.element_size], k: LayoutTensor[k.dtype, k.layout, k.origin, element_layout=k.element_layout, layout_int_type=k.layout_int_type, linear_idx_type=k.linear_idx_type, masked=k.masked, alignment=k.alignment], v: LayoutTensor[v.dtype, v.layout, v.origin, element_layout=v.element_layout, layout_int_type=v.layout_int_type, linear_idx_type=v.linear_idx_type, masked=v.masked, alignment=v.alignment], k_rope: cache_t, mask_functor: mask_t, valid_length: TileTensor[DType.uint32, valid_length.LayoutType, valid_length.origin, linear_idx_type=valid_length.linear_idx_type, element_size=valid_length.element_size], cache_row_offsets: TileTensor[DType.uint32, cache_row_offsets.LayoutType, cache_row_offsets.origin, linear_idx_type=cache_row_offsets.linear_idx_type, element_size=cache_row_offsets.element_size], scale: Float32, ctx: DeviceContext, q_max_seq_len: OptionalReg[Int] = None, cache_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(VariadicList(-1)), MutAnyOrigin]] = None)
MLA prefill kernel that is only called in the optimized compute graph. It only supports ragged Q/K/V inputs.
The Q input has shape [seq_len, num_heads, q_depth]. The K and V inputs have shape [cache_len, num_heads, depth]. The K_rope input is retrieved from the KV cache and has shape [cache_len, 1, q_depth - depth].
Specifically, for DeepSeek V2/3, depth = 128 and q_depth = 192.
When computing attention scores (Q @ K), each K head is smaller than its corresponding Q head. The missing 64 elements of K are retrieved from the K cache and broadcast to all heads. This kernel also handles the case where the output has a reduced depth compared to the input Q.
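As an illustration of these dimensions, here is a minimal Mojo sketch (not the kernel itself; the variable names are illustrative only) working out the per-head depth of K_rope for the DeepSeek V2/3 case described above:

```mojo
fn main():
    # DeepSeek V2/3 head dimensions from the description above.
    var depth = 128                    # per-head depth of K and V
    var q_depth = 192                  # per-head depth of Q
    var rope_depth = q_depth - depth   # elements fetched from the K cache
    # These 64 elements are shared (broadcast) across all Q heads.
    print("k_rope per-head depth:", rope_depth)  # prints 64
```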
This kernel handles batches with different valid lengths (i.e., lengths before padding). These lengths are passed in the valid_length argument.
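For ragged inputs, valid_length holds one length per sequence, and cache_row_offsets is assumed here to hold the corresponding prefix-sum row offsets into the packed cache rows. The sketch below shows how such offsets could be derived from hypothetical per-sequence lengths; it is an illustration of the ragged layout, not part of this API:

```mojo
fn main():
    # Hypothetical per-sequence lengths for a ragged batch of 3 sequences.
    var lengths = List[Int](3, 5, 2)
    # Exclusive prefix sums give the row offset of each sequence in the
    # packed [total_len, ...] buffer: 0, 3, 8, 10.
    var offset = 0
    print("row offset 0:", offset)
    for i in range(len(lengths)):
        offset += lengths[i]
        print("row offset", i + 1, ":", offset)
```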
flare_mla_prefill[rank: Int, mask_t: MHAMask, dtype: DType, //](output: TileTensor[output.dtype, output.LayoutType, output.origin, linear_idx_type=output.linear_idx_type, element_size=output.element_size], q: TileTensor[dtype, q.LayoutType, q.origin, linear_idx_type=q.linear_idx_type, element_size=q.element_size], k: LayoutTensor[k.dtype, k.layout, k.origin, element_layout=k.element_layout, layout_int_type=k.layout_int_type, linear_idx_type=k.linear_idx_type, masked=k.masked, alignment=k.alignment], v: LayoutTensor[v.dtype, v.layout, v.origin, element_layout=v.element_layout, layout_int_type=v.layout_int_type, linear_idx_type=v.linear_idx_type, masked=v.masked, alignment=v.alignment], k_rope: LayoutTensor[k_rope.dtype, k_rope.layout, k_rope.origin, element_layout=k_rope.element_layout, layout_int_type=k_rope.layout_int_type, linear_idx_type=k_rope.linear_idx_type, masked=k_rope.masked, alignment=k_rope.alignment], mask_functor: mask_t, valid_length: TileTensor[DType.uint32, valid_length.LayoutType, valid_length.origin, linear_idx_type=valid_length.linear_idx_type, element_size=valid_length.element_size], cache_row_offsets: LayoutTensor[DType.uint32, cache_row_offsets.layout, cache_row_offsets.origin, element_layout=cache_row_offsets.element_layout, layout_int_type=cache_row_offsets.layout_int_type, linear_idx_type=cache_row_offsets.linear_idx_type, masked=cache_row_offsets.masked, alignment=cache_row_offsets.alignment], scale: Float32, ctx: DeviceContext, q_max_seq_len: OptionalReg[Int] = None, cache_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(VariadicList(-1)), MutAnyOrigin]] = None)
flare_mla_prefill[rank: Int, mask_t: MHAMask, dtype: DType, //](output: TileTensor[output.dtype, output.LayoutType, output.origin, linear_idx_type=output.linear_idx_type, element_size=output.element_size], q: TileTensor[dtype, q.LayoutType, q.origin, linear_idx_type=q.linear_idx_type, element_size=q.element_size], k: LayoutTensor[k.dtype, k.layout, k.origin, element_layout=k.element_layout, layout_int_type=k.layout_int_type, linear_idx_type=k.linear_idx_type, masked=k.masked, alignment=k.alignment], v: LayoutTensor[v.dtype, v.layout, v.origin, element_layout=v.element_layout, layout_int_type=v.layout_int_type, linear_idx_type=v.linear_idx_type, masked=v.masked, alignment=v.alignment], k_rope: LayoutTensor[k_rope.dtype, k_rope.layout, k_rope.origin, element_layout=k_rope.element_layout, layout_int_type=k_rope.layout_int_type, linear_idx_type=k_rope.linear_idx_type, masked=k_rope.masked, alignment=k_rope.alignment], k_rope_scales: LayoutTensor[k_rope_scales.dtype, k_rope_scales.layout, k_rope_scales.origin, element_layout=k_rope_scales.element_layout, layout_int_type=k_rope_scales.layout_int_type, linear_idx_type=k_rope_scales.linear_idx_type, masked=k_rope_scales.masked, alignment=k_rope_scales.alignment], mask_functor: mask_t, valid_length: TileTensor[DType.uint32, valid_length.LayoutType, valid_length.origin, linear_idx_type=valid_length.linear_idx_type, element_size=valid_length.element_size], cache_row_offsets: LayoutTensor[DType.uint32, cache_row_offsets.layout, cache_row_offsets.origin, element_layout=cache_row_offsets.element_layout, layout_int_type=cache_row_offsets.layout_int_type, linear_idx_type=cache_row_offsets.linear_idx_type, masked=cache_row_offsets.masked, alignment=cache_row_offsets.alignment], scale: Float32, ctx: DeviceContext, q_max_seq_len: OptionalReg[Int] = None, cache_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(VariadicList(-1)), MutAnyOrigin]] = None)
flare_mla_prefill[rank: Int, mask_t: MHAMask, dtype: DType, scale_dtype: DType, //](output: TileTensor[output.dtype, output.LayoutType, output.origin, linear_idx_type=output.linear_idx_type, element_size=output.element_size], q_nope: TileTensor[dtype, q_nope.LayoutType, q_nope.origin, linear_idx_type=q_nope.linear_idx_type, element_size=q_nope.element_size], q_rope: LayoutTensor[q_rope.dtype, q_rope.layout, q_rope.origin, element_layout=q_rope.element_layout, layout_int_type=q_rope.layout_int_type, linear_idx_type=q_rope.linear_idx_type, masked=q_rope.masked, alignment=q_rope.alignment], q_scale: LayoutTensor[scale_dtype, q_scale.layout, q_scale.origin, element_layout=q_scale.element_layout, layout_int_type=q_scale.layout_int_type, linear_idx_type=q_scale.linear_idx_type, masked=q_scale.masked, alignment=q_scale.alignment], k: LayoutTensor[k.dtype, k.layout, k.origin, element_layout=k.element_layout, layout_int_type=k.layout_int_type, linear_idx_type=k.linear_idx_type, masked=k.masked, alignment=k.alignment], k_scales: LayoutTensor[scale_dtype, k_scales.layout, k_scales.origin, element_layout=k_scales.element_layout, layout_int_type=k_scales.layout_int_type, linear_idx_type=k_scales.linear_idx_type, masked=k_scales.masked, alignment=k_scales.alignment], v: LayoutTensor[v.dtype, v.layout, v.origin, element_layout=v.element_layout, layout_int_type=v.layout_int_type, linear_idx_type=v.linear_idx_type, masked=v.masked, alignment=v.alignment], k_rope: LayoutTensor[k_rope.dtype, k_rope.layout, k_rope.origin, element_layout=k_rope.element_layout, layout_int_type=k_rope.layout_int_type, linear_idx_type=k_rope.linear_idx_type, masked=k_rope.masked, alignment=k_rope.alignment], mask_functor: mask_t, valid_length: TileTensor[DType.uint32, valid_length.LayoutType, valid_length.origin, linear_idx_type=valid_length.linear_idx_type, element_size=valid_length.element_size], cache_row_offsets: LayoutTensor[DType.uint32, cache_row_offsets.layout, cache_row_offsets.origin, element_layout=cache_row_offsets.element_layout, layout_int_type=cache_row_offsets.layout_int_type, linear_idx_type=cache_row_offsets.linear_idx_type, masked=cache_row_offsets.masked, alignment=cache_row_offsets.alignment], scale: Float32, ctx: DeviceContext, q_max_seq_len: OptionalReg[Int] = None, cache_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(VariadicList(-1)), MutAnyOrigin]] = None)