IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Python module

max.pipelines.architectures.unified_eagle_llama3

EAGLE speculative decoding draft model for Llama 3 with unified graph compilation.

PersistentInputBuffers

class max.pipelines.architectures.unified_eagle_llama3.PersistentInputBuffers(tokens: 'Buffer', input_row_offsets: 'Buffer')

Bases: object

Parameters:

  • tokens (Buffer)
  • input_row_offsets (Buffer)

alloc()

classmethod alloc(max_batch_size, max_batch_input_tokens, device)

Parameters:

  • max_batch_size (int)
  • max_batch_input_tokens (int)
  • device (Device)

Return type:

PersistentInputBuffers
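The class name and the alloc() arguments suggest the usual persistent-buffer pattern: allocate once at the maximum sizes, then write each step's batch into a prefix of those buffers instead of reallocating. The sketch below illustrates that pattern on the host with Python arrays. PersistentBuffersSketch and load_step are hypothetical, the device argument is omitted, and the assumption that input_row_offsets holds max_batch_size + 1 entries (a leading 0 plus one end offset per request) is not confirmed by this page.

```python
from array import array
from dataclasses import dataclass


@dataclass
class PersistentBuffersSketch:
    """Hypothetical host-side stand-in for PersistentInputBuffers."""

    tokens: array             # flat token IDs, sized for max_batch_input_tokens
    input_row_offsets: array  # ragged row boundaries, sized for max_batch_size + 1

    @classmethod
    def alloc(cls, max_batch_size: int, max_batch_input_tokens: int) -> "PersistentBuffersSketch":
        # Allocate once, at the largest sizes any step may need.
        tokens = array("q", [0] * max_batch_input_tokens)
        # Assumption: ragged offsets carry one extra leading entry (always 0).
        offsets = array("q", [0] * (max_batch_size + 1))
        return cls(tokens, offsets)

    def load_step(self, batch: list[list[int]]) -> tuple[int, int]:
        """Write one step's requests into the buffer prefixes without reallocating.

        Returns (tokens used, offsets used) so callers can view only the live prefix.
        """
        total = sum(len(seq) for seq in batch)
        if len(batch) + 1 > len(self.input_row_offsets) or total > len(self.tokens):
            raise ValueError("batch exceeds the preallocated capacity")
        pos = 0
        self.input_row_offsets[0] = 0
        for i, seq in enumerate(batch):
            # Same-length slice assignment overwrites in place; no resize happens.
            self.tokens[pos : pos + len(seq)] = array("q", seq)
            pos += len(seq)
            self.input_row_offsets[i + 1] = pos
        return pos, len(batch) + 1


bufs = PersistentBuffersSketch.alloc(max_batch_size=4, max_batch_input_tokens=16)
n_tokens, n_offsets = bufs.load_step([[1, 2, 3], [4, 5]])
print(list(bufs.tokens[:n_tokens]), list(bufs.input_row_offsets[:n_offsets]))
# [1, 2, 3, 4, 5] [0, 3, 5]
```

Reusing fixed-size buffers keeps their addresses stable across steps, which is presumably why the allocation is tied to the maximum batch sizes rather than to each step's actual batch.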

input_row_offsets

input_row_offsets: Buffer

tokens

tokens: Buffer

UnifiedEagleLlama3Config

class max.pipelines.architectures.unified_eagle_llama3.UnifiedEagleLlama3Config(*, target: 'Llama3Config', draft: 'Llama3Config', speculative_config: 'SpeculativeConfig', enable_structured_output: 'bool' = False)

Bases: ArchConfigWithKVCache

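Parameters:

  • target (Llama3Config)
  • draft (Llama3Config)
  • speculative_config (SpeculativeConfig)
  • enable_structured_output (bool)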

draft

draft: Llama3Config

enable_structured_output

enable_structured_output: bool = False

When True, the graph accepts a bitmask input for grammar-constrained decoding.

get_kv_params()

get_kv_params()

Returns the KV cache parameters to use when running the model.

Return type:

KVCacheParamInterface

get_max_seq_len()

get_max_seq_len()

Returns the default maximum sequence length for the model.

Subclasses should determine whether this value can be overridden by setting the --max-length (pipeline_config.model.max_length) flag.

Return type:

int

initialize()

classmethod initialize(pipeline_config, model_config=None)

Initializes the config from a PipelineConfig.

Parameters:

  • pipeline_config (PipelineConfig) – The pipeline configuration.
  • model_config (MAXModelConfig | None) – The model configuration to read from. When None (the default), pipeline_config.model is used. Pass an explicit config (e.g. pipeline_config.draft_model) to initialize the arch config for a different model.

Return type:

Self
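The model_config parameter follows a simple fallback rule: use the explicit config when one is passed, otherwise use pipeline_config.model. A minimal sketch of that rule, with hypothetical PipelineConfigSketch and ModelConfigSketch dataclasses standing in for PipelineConfig and MAXModelConfig (the real classes carry many more fields, and the model paths are made up):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelConfigSketch:
    # Hypothetical stand-in for MAXModelConfig.
    model_path: str
    max_length: Optional[int] = None


@dataclass
class PipelineConfigSketch:
    # Hypothetical stand-in for PipelineConfig with a target and a draft model.
    model: ModelConfigSketch
    draft_model: Optional[ModelConfigSketch] = None


def resolve_model_config(
    pipeline_config: PipelineConfigSketch,
    model_config: Optional[ModelConfigSketch] = None,
) -> ModelConfigSketch:
    """Mirror the documented default: None means use pipeline_config.model."""
    return pipeline_config.model if model_config is None else model_config


pc = PipelineConfigSketch(
    model=ModelConfigSketch("example/target-llama3"),
    draft_model=ModelConfigSketch("example/eagle-draft"),
)
print(resolve_model_config(pc).model_path)                  # example/target-llama3
print(resolve_model_config(pc, pc.draft_model).model_path)  # example/eagle-draft
```

The explicit "is None" check matters: it selects the fallback only when no config was passed, never because a passed config happens to be falsy.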

speculative_config

speculative_config: SpeculativeConfig

target

target: Llama3Config

UnifiedEagleLlama3Inputs

class max.pipelines.architectures.unified_eagle_llama3.UnifiedEagleLlama3Inputs(tokens, input_row_offsets, return_n_logits, draft_tokens=None, draft_kv_blocks=None, seed=None, temperature=None, top_k=None, max_k=None, top_p=None, min_top_p=None, in_thinking_phase=None, pinned_bitmask=None, wait_payload=None, device_bitmask_scratch=None, *, kv_cache_inputs=None, lora_ids=None, lora_ranks=None, hidden_states=None)

Bases: ModelInputs

Inputs for the unified EAGLE Llama3 model.

Parameters:

buffers

property buffers: tuple[Buffer, ...]

Returns positional Buffer inputs for model ABI calls.
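The buffers property flattens the inputs into a positional tuple for the compiled model's calling convention. This page does not document the ordering or how unset optional fields are handled, so the sketch below simply assumes declaration order and skips fields that are None. InputsSketch and its three fields are hypothetical, with plain lists standing in for Buffer.

```python
from dataclasses import dataclass, fields
from typing import Any, Optional


@dataclass
class InputsSketch:
    # Hypothetical subset of UnifiedEagleLlama3Inputs; lists stand in for Buffer.
    tokens: list
    input_row_offsets: list
    draft_tokens: Optional[list] = None

    @property
    def buffers(self) -> tuple[Any, ...]:
        # Assumption: declaration order defines the positional order, and unset
        # optional inputs are omitted because the graph was compiled without them.
        values = (getattr(self, f.name) for f in fields(self))
        return tuple(v for v in values if v is not None)


inputs = InputsSketch(tokens=[7, 8, 9], input_row_offsets=[0, 3])
print(len(inputs.buffers))  # 2
inputs.draft_tokens = [10]
print(len(inputs.buffers))  # 3
```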

device_bitmask_scratch

device_bitmask_scratch: Buffer | None = None

Device scratch buffer that receives the in-graph host-to-device (H2D) copy of pinned_bitmask; the acceptance sampler reads the bitmask from this buffer. Only set when structured output is enabled.

draft_kv_blocks

draft_kv_blocks: list[Buffer] | None = None

draft_tokens

draft_tokens: Buffer | None = None

in_thinking_phase

in_thinking_phase: Buffer | None = None

Per-batch boolean flag, set by the pipeline, that enables relaxed acceptance during the thinking phase. The unified_eagle_llama3 graph does not currently consume it, but the field is required to satisfy the _UnifiedSpecDecodeInputs protocol used by OverlapTextGenerationPipeline.

input_row_offsets

input_row_offsets: Buffer

max_k

max_k: Buffer | None = None

min_top_p

min_top_p: Buffer | None = None

Per-batch sampling parameters (seed, temperature, top_k, max_k, top_p, and min_top_p) consumed by the stochastic acceptance sampler. max_k and min_top_p are 0-d CPU scalars; seed, temperature, top_k, and top_p are [batch_size] tensors on the primary device.
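The split between per-request tensors and 0-d scalars suggests that max_k and min_top_p act as batch-wide bounds, for example the largest top_k and the smallest top_p in the batch, so one bounded top-k pass can serve every request. That reading is an assumption, not documented behavior; the hypothetical batch_bounds helper below only shows how such bounds would be derived from the per-request values.

```python
def batch_bounds(top_k: list[int], top_p: list[float]) -> tuple[int, float]:
    """Hypothetical: derive batch-wide scalars from per-request [batch_size] values.

    Assumes max_k is the largest requested top_k and min_top_p is the smallest
    requested top_p, so a single bounded top-k pass covers every request.
    """
    if len(top_k) != len(top_p) or not top_k:
        raise ValueError("expected non-empty per-request lists of equal length")
    return max(top_k), min(top_p)


max_k, min_top_p = batch_bounds(top_k=[1, 40, 8], top_p=[0.9, 1.0, 0.75])
print(max_k, min_top_p)  # 40 0.75
```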

pinned_bitmask

pinned_bitmask: Buffer | None = None

Pinned host bitmask for constrained decoding.

Shape [batch_size, num_speculative_tokens + 1, vocab_size]. Position i holds the valid-token mask for the FSM state reached after consuming the first i draft tokens (draft[0:i]); position num_speculative_tokens holds the mask for the bonus token. None when structured output is disabled.
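To make the position convention concrete, the sketch below builds the num_speculative_tokens + 1 mask rows for a single request (one [num_speculative_tokens + 1, vocab_size] slice of the bitmask) from a toy FSM over a four-token vocabulary. The TRANSITIONS table, build_masks, and the use of list-of-bool masks are illustrative assumptions, not the pipeline's grammar backend. As a further simplification, once a draft token leaves the grammar, the remaining rows allow nothing; acceptance would reject at that position anyway.

```python
# Toy grammar over a 4-token vocabulary: state -> {allowed token: next state}.
TRANSITIONS: dict[int, dict[int, int]] = {
    0: {1: 1, 2: 1},  # state 0 accepts token 1 or 2
    1: {3: 2},        # state 1 accepts token 3
    2: {0: 0},        # state 2 accepts token 0
}


def build_masks(start_state: int, draft: list[int], vocab_size: int) -> list[list[bool]]:
    """Return len(draft) + 1 masks; row i uses the state after the first i draft tokens."""
    masks: list[list[bool]] = []
    state = start_state
    for i in range(len(draft) + 1):
        # A state of None (the draft already left the grammar) allows nothing.
        allowed = TRANSITIONS.get(state, {})
        masks.append([tok in allowed for tok in range(vocab_size)])
        if i < len(draft):
            state = allowed.get(draft[i])  # None if draft[i] is not allowed here
    return masks


for row in build_masks(start_state=0, draft=[2, 3, 0], vocab_size=4):
    print("".join("1" if ok else "0" for ok in row))
# 0110  (position 0: before any draft token)
# 0001  (after the first draft token, 2)
# 1000  (after the first two draft tokens)
# 0110  (bonus position, after all three draft tokens)
```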

return_n_logits

return_n_logits: Buffer

seed

seed: Buffer | None = None

temperature

temperature: Buffer | None = None

tokens

tokens: Buffer

top_k

top_k: Buffer | None = None

top_p

top_p: Buffer | None = None

wait_payload

wait_payload: Buffer | None = None

CPU int64[2] payload, [flag._unsafe_ptr, 1], consumed by the in-graph mo.wait_host_value_with_dep op. Only set when structured output is enabled.
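The payload layout and op name suggest a host-side handshake: the graph appears to block until the host value at flag._unsafe_ptr equals 1, so it does not read pinned_bitmask before the CPU has finished writing it. That interpretation is an assumption based on the payload and op name only. The sketch below models it with Python threads, using a one-element list as the flag and a hypothetical wait_host_value function in place of mo.wait_host_value_with_dep.

```python
import threading
import time

EXPECTED = 1              # second payload element: the value to wait for
flag = [0]                # stands in for the host word at flag._unsafe_ptr
bitmask: list[bool] = []  # stands in for pinned_bitmask
seen_by_graph: list[list[bool]] = []


def wait_host_value(cell: list[int], expected: int, poll_s: float = 0.001) -> None:
    """Hypothetical stand-in for the in-graph wait: block until cell[0] == expected."""
    while cell[0] != expected:
        time.sleep(poll_s)


def host_producer() -> None:
    # CPU side: fill the grammar mask first, then publish it by setting the flag.
    bitmask.extend([True, False, True])
    flag[0] = EXPECTED


def graph_consumer() -> None:
    # Graph side: wait on the flag, then read the mask (the real graph copies it H2D).
    wait_host_value(flag, EXPECTED)
    seen_by_graph.append(list(bitmask))


consumer = threading.Thread(target=graph_consumer)
producer = threading.Thread(target=host_producer)
consumer.start()
producer.start()
consumer.join()
producer.join()
print(seen_by_graph)  # [[True, False, True]]
```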

UnifiedEagleLlama3Model

class max.pipelines.architectures.unified_eagle_llama3.UnifiedEagleLlama3Model(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN, return_hidden_states=ReturnHiddenStates.NONE)

Bases: PipelineModelWithKVCache[TextContext]

Unified EAGLE Llama3 model that compiles the target and draft models into a single graph.

Parameters:

execute()

execute(model_inputs)

Executes the model and returns all graph outputs for speculative decoding.

Parameters:

model_inputs (ModelInputs)

Return type:

UnifiedEagleOutputs
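The inputs above refer to a stochastic acceptance sampler that decides which draft tokens survive target verification. For readers new to that step, the sketch below implements the standard speculative-sampling acceptance rule: accept draft token x with probability min(1, p(x) / q(x)), where p is the target distribution and q the draft distribution; on the first rejection, resample from the renormalized residual max(0, p - q); and if every draft token is accepted, sample a bonus token from the target distribution at the final position. It is a generic illustration, not taken from the MAX graph, whose sampler may differ (for example in how top_k and top_p are applied); accept_draft and the list-based distributions are hypothetical.

```python
import random


def sample(dist: list[float], rng: random.Random) -> int:
    """Draw an index from a probability list (weights need not sum exactly to 1)."""
    return rng.choices(range(len(dist)), weights=dist, k=1)[0]


def accept_draft(
    draft: list[int],
    q: list[list[float]],  # draft-model probabilities, one row per draft token
    p: list[list[float]],  # target probabilities, len(draft) + 1 rows (last = bonus)
    rng: random.Random,
) -> list[int]:
    """Standard speculative-sampling acceptance; returns the emitted tokens."""
    out: list[int] = []
    for i, x in enumerate(draft):
        if rng.random() < min(1.0, p[i][x] / q[i][x]):
            out.append(x)  # draft token accepted
            continue
        # Rejected: resample from the residual max(0, p - q), renormalized.
        residual = [max(0.0, pi - qi) for pi, qi in zip(p[i], q[i])]
        total = sum(residual)
        out.append(sample([r / total for r in residual] if total > 0 else p[i], rng))
        return out
    # Every draft token was accepted: emit a bonus token from the final target row.
    out.append(sample(p[len(draft)], rng))
    return out


# The target assigns zero probability to the drafted token 0, so it is rejected
# and the residual forces token 1.
print(accept_draft([0], q=[[1.0, 0.0]], p=[[0.0, 1.0], [0.5, 0.5]], rng=random.Random(0)))  # [1]
```

Each call emits between 1 and len(draft) + 1 tokens, which is why a speculative step can never produce fewer tokens than ordinary decoding.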

load_model()

load_model(session)

Parameters:

session (InferenceSession)

Return type:

Model

model

model: Model

model_config_cls

model_config_cls

alias of Llama3Config

prepare_initial_token_inputs()

prepare_initial_token_inputs(replica_batches, kv_cache_inputs=None, return_n_logits=1)

Prepares the initial inputs to be passed to execute().

The inputs and functionality can vary per model. For example, model inputs could include encoded tensors, unique IDs per tensor when using a KV cache manager, and kv_cache_inputs (or None if the model does not use a KV cache). This method typically batches encoded tensors, claims a KV cache slot if needed, and returns the inputs and caches.

Parameters:

Return type:

UnifiedEagleLlama3Inputs
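The return_n_logits argument (default 1) controls how many logits per request the model returns. A common reading, assumed here rather than confirmed by this page, is that each request keeps the logits of its last return_n_logits positions in the ragged layout described by input_row_offsets. The hypothetical logit_indices helper computes those flat token positions:

```python
def logit_indices(input_row_offsets: list[int], return_n_logits: int) -> list[int]:
    """Hypothetical: flat token positions whose logits would be returned.

    Assumes each request keeps the logits of its last return_n_logits tokens,
    clipped to the request length, in the ragged layout given by the offsets.
    """
    if return_n_logits < 1:
        raise ValueError("return_n_logits must be >= 1")
    indices: list[int] = []
    # Consecutive offset pairs delimit each request: tokens[start:end].
    for start, end in zip(input_row_offsets, input_row_offsets[1:]):
        first = max(start, end - return_n_logits)
        indices.extend(range(first, end))
    return indices


offsets = [0, 3, 5]  # two requests: tokens[0:3] and tokens[3:5]
print(logit_indices(offsets, 1))  # [2, 4]
print(logit_indices(offsets, 2))  # [1, 2, 3, 4]
```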