IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Python class

AttentionDispatchResolver

class max.nn.kv_cache.AttentionDispatchResolver(devices, is_mla, n_kv_heads_per_device, num_q_heads_per_device=None, is_fp8_kv=False)

Bases: object

Resolves packed attention decode metadata via kernel custom ops.

Supports both MHA (mo.mha.decode.get_num_partitions) and MLA (mo.mla.compute_dispatch_args.scalar) decode kernels; the is_mla flag selects which one is used.

Parameters:

  • devices (Sequence[DeviceRef])
  • is_mla (bool)
  • n_kv_heads_per_device (int)
  • num_q_heads_per_device (int | None)
  • is_fp8_kv (bool)
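
As a rough illustration of how the constructor arguments fit together, the sketch below passes each documented parameter by keyword. The helper name build_resolver and the head counts (8 and 32) are hypothetical, chosen only for the example. The class is passed in as an argument so the snippet runs without MAX installed, and the import paths in the trailing comment are assumptions rather than something this page confirms.

```python
from typing import Any, Sequence


def build_resolver(resolver_cls: Any, devices: Sequence[Any], *, is_mla: bool) -> Any:
    """Construct a resolver using the keyword names from the signature above.

    The head counts below are illustrative placeholders, not recommended values.
    """
    return resolver_cls(
        devices=devices,
        is_mla=is_mla,                # selects the MLA or MHA decode kernel
        n_kv_heads_per_device=8,      # illustrative value
        num_q_heads_per_device=32,    # illustrative value; documented default is None
        is_fp8_kv=False,              # documented default
    )


# With MAX installed, usage might look roughly like this (assumed imports):
#   from max.graph import DeviceRef
#   from max.nn.kv_cache import AttentionDispatchResolver
#   resolver = build_resolver(AttentionDispatchResolver, [DeviceRef.GPU()], is_mla=False)
```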

probe_lengths()

probe_lengths(max_cache_length, q_max_seq_len=1)

Returns the cache lengths to probe so that each distinct num_partitions value is covered.

These are the cache lengths warmed up during graph capture. MHA probes at 256-token granularity; MLA probes at a finer 64-token granularity (and, under speculative decoding, adds extra probes to hit more (num_partitions, draft_num_partitions) pairs). The selected granularity follows is_mla.

Parameters:

  • max_cache_length (int)
  • q_max_seq_len (int)

Return type:

list[int]
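
To make the granularity idea concrete, here is a minimal sketch of one way such probing could work. It is not the library's implementation. It assumes the probe walks the cache length in fixed steps (256 for MHA, 64 for MLA, as stated above) and keeps the first length that yields each new partition count. The names sketch_probe_lengths and toy_num_partitions are hypothetical, the partition rule is invented for the example, and the speculative-decoding extras and q_max_seq_len are ignored.

```python
from typing import Callable


def toy_num_partitions(cache_length: int) -> int:
    """Invented stand-in for the kernel query: one partition per 512 tokens."""
    return max(1, -(-cache_length // 512))  # ceiling division


def sketch_probe_lengths(
    max_cache_length: int,
    is_mla: bool,
    num_partitions_for: Callable[[int], int] = toy_num_partitions,
) -> list[int]:
    """Return the first stepped cache length that produces each partition count."""
    step = 64 if is_mla else 256  # granularities quoted in the description
    probes: list[int] = []
    seen: set[int] = set()
    for length in range(step, max_cache_length + step, step):
        length = min(length, max_cache_length)  # clamp the final step
        partitions = num_partitions_for(length)
        if partitions not in seen:
            seen.add(partitions)
            probes.append(length)
    return probes


print(sketch_probe_lengths(1024, is_mla=False))  # [256, 768]
print(sketch_probe_lengths(1024, is_mla=True))   # [64, 576]
```

Under the toy rule, the finer MLA step finds the boundary between one and two partitions at 576 tokens, while the coarser MHA step first sees it at 768.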

resolve_attn_key()

resolve_attn_key(batch_size, max_prompt_length, max_cache_valid_length)

Returns the resolved decode dispatch key for the given shape.

For empty or degenerate replicas (batch_size <= 0, or a CPU-only resolver), the method returns a sentinel key (num_partitions=1) without invoking the dispatch kernels.

Parameters:

  • batch_size (int)
  • max_prompt_length (int)
  • max_cache_valid_length (int)

Return type:

AttnKey
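
The sentinel short-circuit described for resolve_attn_key can be pictured with the sketch below. It is a stand-in under stated assumptions, not the real method. SketchKey is a hypothetical type that models only the num_partitions field mentioned above (the actual AttnKey fields are not listed on this page), and has_gpu and run_kernel are hypothetical parameters standing in for the resolver's device check and kernel call.

```python
from typing import Callable, NamedTuple


class SketchKey(NamedTuple):
    """Hypothetical stand-in for AttnKey; models only num_partitions."""
    num_partitions: int


SENTINEL_KEY = SketchKey(num_partitions=1)


def sketch_resolve_attn_key(
    batch_size: int,
    max_prompt_length: int,
    max_cache_valid_length: int,
    *,
    has_gpu: bool,
    run_kernel: Callable[[int, int, int], SketchKey],
) -> SketchKey:
    """Return the sentinel for degenerate inputs; otherwise ask the kernel."""
    if batch_size <= 0 or not has_gpu:
        # Empty replica or CPU-only resolver: skip the dispatch kernel entirely.
        return SENTINEL_KEY
    return run_kernel(batch_size, max_prompt_length, max_cache_valid_length)


calls: list[tuple[int, int, int]] = []


def fake_kernel(batch: int, prompt_len: int, cache_len: int) -> SketchKey:
    """Invented kernel stand-in that records its inputs."""
    calls.append((batch, prompt_len, cache_len))
    return SketchKey(num_partitions=4)


empty = sketch_resolve_attn_key(0, 1, 2048, has_gpu=True, run_kernel=fake_kernel)
busy = sketch_resolve_attn_key(8, 1, 2048, has_gpu=True, run_kernel=fake_kernel)
```

Here the empty replica receives the sentinel without the kernel stand-in being called, and only the second call reaches it.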