
Mojo function

flare_mla_decoding

def flare_mla_decoding[rank: Int, cache_t: KVCacheT, mask_t: MHAMask, dtype: DType, //, config: MHAConfig[dtype], ragged: Bool = False, decoding_warp_split_k: Bool = False, per_token_scale_rope_aware: Bool = False, sparse: Bool = False, rope_aware_kv_sparse: Bool = False](output: TileTensor[Storage=output.Storage, linear_idx_type=output.linear_idx_type, element_size=output.element_size], q: TileTensor[dtype, Storage=q.Storage, linear_idx_type=q.linear_idx_type, element_size=q.element_size], k: cache_t, mask_functor: mask_t, valid_length: TileTensor[DType.uint32, Storage=valid_length.Storage, linear_idx_type=valid_length.linear_idx_type, element_size=valid_length.element_size], scale: Float32, ctx: DeviceContext, scalar_args_buf: NullableTileTensor[DType.int64, linear_idx_type=scalar_args_buf.linear_idx_type, element_size=scalar_args_buf.element_size], q_max_seq_len: OptionalReg[Int] = None, kv_input_row_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(Int(-1)), ImmutAnyOrigin]] = None, num_partitions: Optional[Int] = None, q_scale_ptr: OptionalReg[UnsafePointer[Float32, MutAnyOrigin]] = None, d_indices: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None, indices_stride: Int = Int(0), topk_lengths: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None, attn_sink_ptr: OptionalReg[UnsafePointer[Float32, MutAnyOrigin]] = None, extra_k: OptionalReg[cache_t] = None, extra_d_indices: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None, extra_indices_stride: Int = Int(0), extra_topk_lengths: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None, extra_scales_ptr: OptionalReg[UnsafePointer[Float32, MutAnyOrigin]] = None, num_partitions_in: Optional[Int] = None)

MLA decoding kernel, intended to be called only from the optimized compute graph.

The Q input has a shape of [seq_len, num_heads, depth]. The K input has a shape of [seq_len, 1, depth]. The V tensor is derived by reusing K, where V = K[:, :, :depth_v].

Specifically, for DeepSeek V2/3, depth = 576 and depth_v = 512.
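A plain-Python sketch of these shapes (an illustration only, not the Mojo API): it builds a toy K cache of shape [seq_len, 1, depth] as nested lists and derives V by keeping the first depth_v entries of each K row, mirroring V = K[:, :, :depth_v]. The helpers `shape` and `derive_v` and the toy values are hypothetical; depth = 576 and depth_v = 512 are the DeepSeek V2/3 values given above.

```python
# Hypothetical illustration of V = K[:, :, :depth_v] using nested Python lists.
# K has shape [seq_len, 1, depth]; the derived V has shape [seq_len, 1, depth_v].

DEPTH = 576    # DeepSeek V2/3 K depth
DEPTH_V = 512  # DeepSeek V2/3 V depth


def shape(t):
    """Return the shape of a rectangular nested list."""
    dims = []
    while isinstance(t, list):
        dims.append(len(t))
        t = t[0] if t else None
    return tuple(dims)


def derive_v(k, depth_v):
    """Keep the first depth_v entries of every K row: V[s][h] = K[s][h][:depth_v]."""
    return [[row[:depth_v] for row in heads] for heads in k]


seq_len = 4
# Toy cache: one KV head per position, filled with distinct values.
k = [[[float(s * DEPTH + d) for d in range(DEPTH)]] for s in range(seq_len)]
v = derive_v(k, DEPTH_V)

print(shape(k))  # (4, 1, 576)
print(shape(v))  # (4, 1, 512)
```

Python list slices copy their elements, so this sketch models only the shapes and values of V, not the memory reuse between K and V.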

When per_token_scale_rope_aware is True, Q and the KV cache use an interleaved FP8+BF16 layout: FP8 content (512 bytes) + BF16 rope (128 bytes) = 640 bytes per row. Q's last dimension is therefore 640 (counted in FP8, i.e., 1-byte elements), but it represents 576 logical dimensions: 512 nope + 64 rope, since the 128 rope bytes hold 64 BF16 values.
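The row arithmetic can be checked with a small Python sketch. It assumes, beyond what is stated above, that the 512 FP8 content bytes come first in each row, followed by the 64 BF16 rope values stored little-endian; the function `split_rope_aware_row` is hypothetical. FP8 bytes are returned raw rather than decoded, because the FP8 variant (E4M3 or E5M2) is not specified here.

```python
import struct

NOPE_DIM = 512  # FP8 content elements, 1 byte each
ROPE_DIM = 64   # BF16 rope elements, 2 bytes each
ROW_BYTES = NOPE_DIM * 1 + ROPE_DIM * 2  # 512 + 128 = 640 bytes per row


def bf16_to_float(bits):
    """Widen a 16-bit bfloat16 pattern to float32 by zero-filling the low mantissa bits."""
    return struct.unpack("<f", struct.pack("<I", bits << 16))[0]


def split_rope_aware_row(row):
    """Split one 640-byte row into raw FP8 content bytes and decoded BF16 rope values.

    Assumes content-first ordering and little-endian BF16 storage.
    """
    if len(row) != ROW_BYTES:
        raise ValueError(f"expected {ROW_BYTES} bytes, got {len(row)}")
    fp8_content = bytes(row[:NOPE_DIM])
    rope_bits = struct.unpack(f"<{ROPE_DIM}H", row[NOPE_DIM:])
    rope = [bf16_to_float(b) for b in rope_bits]
    return fp8_content, rope


# Toy row: 512 arbitrary FP8 bytes, then 64 BF16 values equal to 1.0 (bit pattern 0x3F80).
row = bytes(range(256)) * 2 + struct.pack(f"<{ROPE_DIM}H", *([0x3F80] * ROPE_DIM))
content, rope = split_rope_aware_row(row)

# 640 physical bytes per row, 576 logical dimensions (512 nope + 64 rope).
print(ROW_BYTES, NOPE_DIM + ROPE_DIM, len(content), len(rope), rope[0])  # 640 576 512 64 1.0
```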

The kernel computes attention without needing to load V twice. It handles only decoding requests, for which q_max_seq_len = 1.

The kernel handles batches whose sequences have different valid (unpadded) lengths. These lengths are passed in the valid_length argument.
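For intuition about the intended math (not the GPU algorithm), here is a naive pure-Python reference for one decoding step. It assumes a padded cache laid out as k[batch][position][depth] with the single KV head dropped, one query token per sequence (q_max_seq_len = 1), standard scaled softmax attention, and no extra mask functor. Each sequence attends only to its first valid_length[b] positions, and V is read as the first depth_v entries of each K row. The function and argument names are hypothetical.

```python
import math


def mla_decode_reference(q, k, valid_length, scale, depth_v):
    """Naive single-step MLA decoding reference (hypothetical helper).

    q:            [batch][num_heads][depth]    one query token per sequence
    k:            [batch][max_seq_len][depth]  padded cache, single KV head dropped
    valid_length: [batch]                      unpadded length, assumed >= 1
    returns:      [batch][num_heads][depth_v]
    """
    out = []
    for b, q_b in enumerate(q):
        keys = k[b][:valid_length[b]]             # drop padded positions
        values = [row[:depth_v] for row in keys]  # V = K[..., :depth_v]
        heads = []
        for q_h in q_b:
            scores = [scale * sum(x * y for x, y in zip(q_h, key)) for key in keys]
            m = max(scores)  # subtract the max for a numerically stable softmax
            weights = [math.exp(s - m) for s in scores]
            total = sum(weights)
            heads.append([
                sum(w * val[d] for w, val in zip(weights, values)) / total
                for d in range(depth_v)
            ])
        out.append(heads)
    return out


# Toy example: depth = 3, depth_v = 2, one head, max_seq_len = 3.
pad = [1000.0, 1000.0, 1000.0]  # padding rows must not affect the result
q = [[[0.5, 0.5, 0.5]], [[0.0, 0.0, 0.0]]]
k = [
    [[1.0, 2.0, 9.0], pad, pad],              # valid_length 1
    [[2.0, 4.0, 0.0], [6.0, 8.0, 0.0], pad],  # valid_length 2
]
out = mla_decode_reference(q, k, valid_length=[1, 2], scale=1 / math.sqrt(3), depth_v=2)
print(out)  # [[[1.0, 2.0]], [[4.0, 6.0]]]
```

The first sequence has a single valid position, so its output is simply that row's first depth_v values; the second averages two rows because a zero query gives equal scores. This loop only pins down what the output should contain; it says nothing about how the kernel schedules the work on the GPU.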

def flare_mla_decoding[mask_t: MHAMask, dtype: DType, //, config: MHAConfig[dtype], decoding_warp_split_k: Bool = False](output: TileTensor[Storage=output.Storage, linear_idx_type=output.linear_idx_type, element_size=output.element_size], q: TileTensor[dtype, Storage=q.Storage, linear_idx_type=q.linear_idx_type, element_size=q.element_size], k: TileTensor[Storage=k.Storage, linear_idx_type=k.linear_idx_type, element_size=k.element_size], mask_functor: mask_t, scale: Float32, ctx: DeviceContext, scalar_args_buf: NullableTileTensor[DType.int64, linear_idx_type=scalar_args_buf.linear_idx_type, element_size=scalar_args_buf.element_size], num_partitions: Optional[Int] = None)