Python module

max.pipelines.architectures.unified_eagle_llama3

EAGLE speculative decoding draft model for Llama 3 with unified graph compilation.

PersistentInputBuffers​

class max.pipelines.architectures.unified_eagle_llama3.PersistentInputBuffers(tokens: 'Buffer', input_row_offsets: 'Buffer')

Bases: object

Parameters:

  • tokens (Buffer)
  • input_row_offsets (Buffer)

alloc()​

classmethod alloc(max_batch_size, max_batch_input_tokens, device)

Parameters:

  • max_batch_size (int)
  • max_batch_input_tokens (int)
  • device (Device)

Return type:

PersistentInputBuffers
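
For orientation, a minimal usage sketch: allocate the persistent buffers once at startup so per-step token and offset copies reuse the same memory. The sizes and the CPU device below are illustrative assumptions, not recommended values.

from max.driver import CPU

from max.pipelines.architectures.unified_eagle_llama3 import (
    PersistentInputBuffers,
)

# Allocate once, sized for the worst-case batch (illustrative values).
buffers = PersistentInputBuffers.alloc(
    max_batch_size=32,
    max_batch_input_tokens=8192,
    device=CPU(),
)
# buffers.tokens and buffers.input_row_offsets are then reused across steps.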

input_row_offsets​

input_row_offsets: Buffer

tokens​

tokens: Buffer

UnifiedEagleLlama3Config​

class max.pipelines.architectures.unified_eagle_llama3.UnifiedEagleLlama3Config(*, target: 'Llama3Config', draft: 'Llama3Config', speculative_config: 'SpeculativeConfig')

Bases: ArchConfigWithKVCache

Parameters:

  • target (Llama3Config)
  • draft (Llama3Config)
  • speculative_config (SpeculativeConfig)

draft​

draft: Llama3Config

get_kv_params()​

get_kv_params()

KV cache parameters to use when running the model.

Return type:

KVCacheParamInterface

get_max_seq_len()​

get_max_seq_len()

Returns the default maximum sequence length for the model.

Subclasses should determine whether this value can be overridden with the --max-length flag (pipeline_config.model.max_length).

Return type:

int

initialize()​

classmethod initialize(pipeline_config, model_config=None)

Initialize the config from a PipelineConfig.

Parameters:

  • pipeline_config (PipelineConfig) – The pipeline configuration.
  • model_config (MAXModelConfig | None) – The model configuration to read from. When None (the default), pipeline_config.model is used. Pass an explicit config (e.g. pipeline_config.draft_model) to initialize the arch config for a different model.

Return type:

Self
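
A hedged sketch of the two initialization paths described above, assuming a pre-existing pipeline_config with a draft model configured:

from max.pipelines.architectures.unified_eagle_llama3 import (
    UnifiedEagleLlama3Config,
)

# Default path: reads pipeline_config.model.
cfg = UnifiedEagleLlama3Config.initialize(pipeline_config)

# Explicit override: build the arch config against the draft model instead.
draft_cfg = UnifiedEagleLlama3Config.initialize(
    pipeline_config, model_config=pipeline_config.draft_model
)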

speculative_config​

speculative_config: SpeculativeConfig

target​

target: Llama3Config

UnifiedEagleLlama3Inputs​

class max.pipelines.architectures.unified_eagle_llama3.UnifiedEagleLlama3Inputs(tokens, input_row_offsets, return_n_logits, draft_tokens=None, draft_kv_blocks=None, seed=None, temperature=None, top_k=None, max_k=None, top_p=None, min_top_p=None, in_thinking_phase=None, *, kv_cache_inputs=None, lora_ids=None, lora_ranks=None, hidden_states=None)

Bases: ModelInputs

Inputs for the unified EAGLE Llama3 model.

Parameters:

  • tokens (Buffer)
  • input_row_offsets (Buffer)
  • return_n_logits (Buffer)
  • draft_tokens (Buffer | None)
  • draft_kv_blocks (list[Buffer] | None)
  • seed (Buffer | None)
  • temperature (Buffer | None)
  • top_k (Buffer | None)
  • max_k (Buffer | None)
  • top_p (Buffer | None)
  • min_top_p (Buffer | None)
  • in_thinking_phase (Buffer | None)
  • kv_cache_inputs
  • lora_ids
  • lora_ranks
  • hidden_states

buffers​

property buffers: tuple[Buffer, ...]

Returns positional Buffer inputs for model ABI calls.
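
As a hedged consumption sketch (the names model and model_inputs are illustrative): the property flattens the dataclass fields into the positional Buffer arguments a compiled-graph call expects.

# `model` is a loaded max.engine Model; `model_inputs` is a prepared
# UnifiedEagleLlama3Inputs. The Buffer ordering matches the graph ABI.
outputs = model.execute(*model_inputs.buffers)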

draft_kv_blocks​

draft_kv_blocks: list[Buffer] | None = None

draft_tokens​

draft_tokens: Buffer | None = None

in_thinking_phase​

in_thinking_phase: Buffer | None = None

Per-batch bool flag set by the pipeline for relaxed acceptance during thinking. Not consumed by the unified_eagle_llama3 graph today, but the field is required to satisfy the _UnifiedEagleInputs protocol used by OverlapTextGenerationPipeline.

input_row_offsets​

input_row_offsets: Buffer

max_k​

max_k: Buffer | None = None

min_top_p​

min_top_p: Buffer | None = None

Per-batch sampling parameters consumed by the stochastic acceptance sampler. max_k and min_top_p are 0-d CPU scalars; the rest are [batch_size] tensors on the primary device.
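
A hedged sketch of that layout using max.driver tensors; Tensor stands in for the Buffer annotation here, and the dtypes and values are illustrative assumptions:

import numpy as np
from max.driver import CPU, Tensor

device = CPU()  # stand-in for the primary device in this sketch
batch_size = 4

# Per-request knobs: [batch_size] tensors on the primary device.
top_k = Tensor.from_numpy(np.full(batch_size, 40, dtype=np.int64)).to(device)
top_p = Tensor.from_numpy(np.full(batch_size, 0.9, dtype=np.float32)).to(device)

# Batch-wide reductions: 0-d CPU scalars.
max_k = Tensor.from_numpy(np.array(40, dtype=np.int64))  # max over top_k
min_top_p = Tensor.from_numpy(np.array(0.9, dtype=np.float32))  # min over top_p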

return_n_logits​

return_n_logits: Buffer

seed​

seed: Buffer | None = None

Per-execute int64 scalar seed consumed by the stochastic acceptance sampler (and, when enabled, the synthetic benchmarking sampler).

temperature​

temperature: Buffer | None = None

tokens​

tokens: Buffer

top_k​

top_k: Buffer | None = None

top_p​

top_p: Buffer | None = None

UnifiedEagleLlama3Model​

class max.pipelines.architectures.unified_eagle_llama3.UnifiedEagleLlama3Model(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN, return_hidden_states=ReturnHiddenStates.NONE)

Bases: PipelineModelWithKVCache[TextContext]

Unified EAGLE Llama3: target + draft in one compiled graph.

Parameters:

  • pipeline_config (PipelineConfig)
  • session (InferenceSession)
  • devices
  • kv_cache_config
  • weights
  • adapter
  • return_logits (ReturnLogits)
  • return_hidden_states (ReturnHiddenStates)

calculate_max_seq_len()​

classmethod calculate_max_seq_len(pipeline_config, huggingface_config)

Calculates the optimal max sequence length for the model.

Models are expected to implement this method. The following example shows how to implement it for a Mistral model:

# Assumed imports for this example; the exact module paths may differ.
from max.pipelines.lib import PipelineModel, upper_bounded_default

class MistralModel(PipelineModel):
    @classmethod
    def calculate_max_seq_len(cls, pipeline_config, huggingface_config) -> int:
        # Clamp the user-requested max_length to the model's own limit.
        try:
            return upper_bounded_default(
                upper_bound=huggingface_config.max_seq_len,
                default=pipeline_config.model.max_length,
            )
        except ValueError as e:
            raise ValueError(
                "Unable to infer max_length for Mistral, the provided "
                f"max_length ({pipeline_config.model.max_length}) exceeds the "
                f"model's max_seq_len ({huggingface_config.max_seq_len})."
            ) from e

Parameters:

  • pipeline_config (PipelineConfig) – Configuration for the pipeline.
  • huggingface_config (AutoConfig) – Hugging Face model configuration.

Returns:

The maximum sequence length to use.

Return type:

int

execute()​

execute(model_inputs)

Execute and return all graph outputs for speculative decoding.

Parameters:

model_inputs (ModelInputs)

Return type:

UnifiedEagleOutputs
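
A hedged single-step sketch, assuming model is a loaded UnifiedEagleLlama3Model, replica_batches holds the per-replica TextContext batches, and kv_cache_inputs came from the KV cache manager:

inputs = model.prepare_initial_token_inputs(
    replica_batches, kv_cache_inputs=kv_cache_inputs
)
# One unified call runs target and draft and returns all graph outputs.
outputs = model.execute(inputs)  # UnifiedEagleOutputs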

get_kv_params()​

classmethod get_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)

Returns the KV cache params for the pipeline model.

Parameters:

  • huggingface_config (AutoConfig)
  • pipeline_config (PipelineConfig)
  • devices
  • kv_cache_config
  • cache_dtype

Return type:

KVCacheParams

load_model()​

load_model(session)

Parameters:

session (InferenceSession)

Return type:

Model

model​

model: Model

prepare_initial_token_inputs()​

prepare_initial_token_inputs(replica_batches, kv_cache_inputs=None, return_n_logits=1)

Prepares the initial inputs to be passed to execute().

The inputs and functionality can vary per model. For example, model inputs could include encoded tensors, unique IDs per tensor when using a KV cache manager, and kv_cache_inputs (or None if the model does not use KV cache). This method typically batches encoded tensors, claims a KV cache slot if needed, and returns the inputs and caches.

Parameters:

  • replica_batches
  • kv_cache_inputs
  • return_n_logits (int)

Return type:

UnifiedEagleLlama3Inputs

prepare_next_token_inputs()​

prepare_next_token_inputs(next_tokens, prev_model_inputs)

Prepares the secondary inputs to be passed to execute().

While prepare_initial_token_inputs is responsible for managing the initial inputs, this function is responsible for updating the inputs for each step in a multi-step execution pattern.

Parameters:

  • next_tokens
  • prev_model_inputs

Return type:

UnifiedEagleLlama3Inputs
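
Continuing the single-step sketch from execute() above, a hedged multi-step loop; the extraction of accepted tokens from UnifiedEagleOutputs is left schematic because that type's fields are not documented here, and num_steps is an assumed loop bound:

inputs = model.prepare_initial_token_inputs(
    replica_batches, kv_cache_inputs=kv_cache_inputs
)
for _ in range(num_steps):
    outputs = model.execute(inputs)
    next_tokens = ...  # pull accepted tokens from `outputs` (schematic only)
    inputs = model.prepare_next_token_inputs(next_tokens, inputs)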