Mojo function

flare_mla_decoding

flare_mla_decoding[rank: Int, cache_t: KVCacheT, mask_t: MHAMask, dtype: DType, //, config: MHAConfig[dtype], ragged: Bool = False, decoding_warp_split_k: Bool = False, per_token_scale_rope_aware: Bool = False, sparse: Bool = False](output: TileTensor[output.dtype, output.LayoutType, output.origin, linear_idx_type=output.linear_idx_type, element_size=output.element_size], q: TileTensor[dtype, q.LayoutType, q.origin, linear_idx_type=q.linear_idx_type, element_size=q.element_size], k: cache_t, mask_functor: mask_t, valid_length: TileTensor[DType.uint32, valid_length.LayoutType, valid_length.origin, linear_idx_type=valid_length.linear_idx_type, element_size=valid_length.element_size], scale: Float32, ctx: DeviceContext, scalar_args_buf: TileTensor[DType.int64, scalar_args_buf.LayoutType, scalar_args_buf.origin, linear_idx_type=scalar_args_buf.linear_idx_type, element_size=scalar_args_buf.element_size], q_max_seq_len: OptionalReg[Int] = None, kv_input_row_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), ImmutAnyOrigin]] = None, num_partitions: Optional[Int] = None, q_scale_ptr: OptionalReg[UnsafePointer[Float32, MutAnyOrigin]] = None, d_indices: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None, indices_stride: Int = 0, topk_lengths: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None, attn_sink_ptr: OptionalReg[UnsafePointer[Float32, MutAnyOrigin]] = None, extra_k: OptionalReg[cache_t] = None, extra_d_indices: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None, extra_indices_stride: Int = 0, extra_topk_lengths: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None, extra_scales_ptr: OptionalReg[UnsafePointer[Float32, MutAnyOrigin]] = None)

MLA decoding kernel that is only called in the optimized compute graph.

The Q input has a shape of [seq_len, num_heads, depth]. The K input has a shape of [seq_len, 1, depth]. The V tensor is derived by reusing K, where V = K[:, :, :depth_v].

Specifically, for DeepSeek V2/3, depth = 576 and depth_v = 512.
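The V-reuse scheme above can be sketched in NumPy. This is an illustrative reference computation, not the kernel itself; the batch/sequence dimensions and the softmax formulation are assumptions chosen for clarity, while the shapes and the depth values follow this page.

```python
import numpy as np

# Illustrative sketch (not the kernel): MLA decoding attention where V is a
# view into K rather than a separate tensor. Shapes follow the doc:
# Q: [seq_len, num_heads, depth], K: [kv_len, 1, depth], V = K[:, :, :depth_v].
seq_len, num_heads, kv_len = 1, 8, 16   # decoding: q_max_seq_len = 1
depth, depth_v = 576, 512               # DeepSeek V2/V3 values
scale = 1.0 / np.sqrt(depth)

rng = np.random.default_rng(0)
q = rng.standard_normal((seq_len, num_heads, depth)).astype(np.float32)
k = rng.standard_normal((kv_len, 1, depth)).astype(np.float32)
v = k[:, :, :depth_v]                   # V reuses K's first depth_v columns

# Attention scores over the single KV head, broadcast across num_heads.
scores = np.einsum("shd,tkd->hst", q, k) * scale
p = np.exp(scores - scores.max(axis=-1, keepdims=True))
p /= p.sum(axis=-1, keepdims=True)
out = np.einsum("hst,tkd->shd", p, v)   # output depth is depth_v = 512
```

Because V is a slice of K, the kernel never materializes or reloads a separate V tensor; the output's head dimension shrinks from depth (576) to depth_v (512).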

When per_token_scale_rope_aware is True, Q and KV cache have an interleaved FP8+BF16 layout: FP8 content (512 bytes) + BF16 rope (128 bytes) = 640 bytes/row. Q's last dimension is 640 (FP8 elements) but represents 576 logical dimensions (512 nope + 64 rope).
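The byte accounting behind that 640-byte row can be checked with a few lines of arithmetic. This is a sketch of the layout math only; the constant names are hypothetical, and the element sizes (1 byte for FP8, 2 bytes for BF16) are standard.

```python
# Byte accounting for the interleaved FP8+BF16 row layout described above.
NOPE_DIMS, ROPE_DIMS = 512, 64          # logical split: 512 nope + 64 rope = 576
FP8_BYTES, BF16_BYTES = 1, 2

nope_bytes = NOPE_DIMS * FP8_BYTES      # 512 bytes of FP8 content
rope_bytes = ROPE_DIMS * BF16_BYTES     # 128 bytes of BF16 rope
row_bytes = nope_bytes + rope_bytes     # 640 bytes per row

# Reinterpreted as FP8 elements, Q's last dimension is 640 elements,
# even though it encodes only 576 logical dimensions.
fp8_view_dim = row_bytes // FP8_BYTES
logical_dims = NOPE_DIMS + ROPE_DIMS
```

This is why Q's trailing dimension (640) differs from the logical head depth (576): the 64 BF16 rope values occupy 128 bytes, which read as 128 extra FP8 elements.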

This kernel computes attention without loading V a second time, since V is a view into K. It only handles decoding requests, in which case q_max_seq_len = 1.

This kernel handles batches with different valid lengths (i.e., lengths before padding). These lengths are passed in the valid_length argument.

flare_mla_decoding[mask_t: MHAMask, dtype: DType, //, config: MHAConfig[dtype], decoding_warp_split_k: Bool = False](output: TileTensor[output.dtype, output.LayoutType, output.origin, linear_idx_type=output.linear_idx_type, element_size=output.element_size], q: TileTensor[dtype, q.LayoutType, q.origin, linear_idx_type=q.linear_idx_type, element_size=q.element_size], k: TileTensor[k.dtype, k.LayoutType, k.origin, linear_idx_type=k.linear_idx_type, element_size=k.element_size], mask_functor: mask_t, scale: Float32, ctx: DeviceContext, scalar_args_buf: TileTensor[DType.int64, scalar_args_buf.LayoutType, scalar_args_buf.origin, linear_idx_type=scalar_args_buf.linear_idx_type, element_size=scalar_args_buf.element_size], num_partitions: Optional[Int] = None)
