Python class

AttentionDispatchResolver
class max.nn.kv_cache.AttentionDispatchResolver(devices, is_mla, n_kv_heads_per_device, num_q_heads_per_device=None, is_fp8_kv=False)

Bases: object

Resolves packed attention decode metadata via kernel custom ops.

Supports both the MHA (mo.mha.decode.get_num_partitions) and MLA (mo.mla.compute_dispatch_args.scalar) decode kernels. The mode is selected automatically from kv_params.is_mla.

Parameters:

  • devices – the devices (one per shard) in the replica
  • is_mla (bool) – select the MLA decode path instead of MHA
  • n_kv_heads_per_device (int)
  • num_q_heads_per_device (default: None)
  • is_fp8_kv (bool, default: False)

resolve_for_replica()

resolve_for_replica(batch_size, max_prompt_length, max_cache_valid_length)

Returns one dispatch-metadata buffer per shard in a replica.

Parameters:

  • batch_size (int)
  • max_prompt_length (int)
  • max_cache_valid_length (int)

Return type:

list[Buffer]
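To make the shape of this API concrete, here is a minimal pure-Python sketch of how a resolver like this might be constructed and queried. The class and field names mirror the documented signature above, but the body is a hypothetical stand-in: the real resolver dispatches to MAX kernel custom ops and returns kernel metadata buffers, whereas this toy version returns plain dicts so the per-shard control flow is visible.

```python
from dataclasses import dataclass
from typing import Optional

# Hypothetical analogue of AttentionDispatchResolver, for illustration only.
# The real class invokes kernel custom ops (mo.mha.decode.get_num_partitions
# for MHA, mo.mla.compute_dispatch_args.scalar for MLA); here we substitute
# simple placeholder metadata records.


@dataclass
class ToyDispatchResolver:
    devices: list                              # one entry per shard
    is_mla: bool                               # MLA vs. MHA decode kernel
    n_kv_heads_per_device: int
    num_q_heads_per_device: Optional[int] = None
    is_fp8_kv: bool = False

    def resolve_for_replica(self, batch_size, max_prompt_length,
                            max_cache_valid_length):
        """Return one dispatch-metadata record per shard in the replica."""
        # Mode is chosen automatically from the MLA flag, mirroring the
        # documented behavior.
        kernel = "mla" if self.is_mla else "mha"
        return [
            {
                "device": dev,
                "kernel": kernel,
                "batch_size": batch_size,
                "max_prompt_length": max_prompt_length,
                "max_cache_valid_length": max_cache_valid_length,
            }
            for dev in self.devices
        ]


# Usage sketch: one metadata record comes back per shard.
resolver = ToyDispatchResolver(
    devices=["gpu:0", "gpu:1"], is_mla=False, n_kv_heads_per_device=8
)
metadata = resolver.resolve_for_replica(
    batch_size=4, max_prompt_length=128, max_cache_valid_length=2048
)
```

The returned list is positionally aligned with `devices`, matching the documented contract of one dispatch-metadata buffer per shard.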