For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).
Python module
max.pipelines.architectures.unified_eagle_llama3
EAGLE speculative decoding draft model for Llama 3 with unified graph compilation.
PersistentInputBuffers
class max.pipelines.architectures.unified_eagle_llama3.PersistentInputBuffers(tokens: 'Buffer', input_row_offsets: 'Buffer')
Bases: object
alloc()
classmethod alloc(max_batch_size, max_batch_input_tokens, device)
Parameters:
Return type:
input_row_offsets
input_row_offsets: Buffer
tokens
tokens: Buffer
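The two fields pair a flat token buffer with row offsets. The sketch below shows one way such a ragged layout can be built and read back in plain Python. It assumes that input_row_offsets is an exclusive prefix sum of sequence lengths with batch_size + 1 entries, which this page does not state; pack_ragged and unpack_ragged are hypothetical helper names, not part of the module.

```python
from itertools import accumulate


def pack_ragged(sequences):
    """Flatten variable-length token lists and compute row offsets.

    Assumption: offsets are an exclusive prefix sum of sequence lengths,
    so request i spans tokens[offsets[i]:offsets[i + 1]].
    """
    tokens = [tok for seq in sequences for tok in seq]
    offsets = [0, *accumulate(len(seq) for seq in sequences)]
    return tokens, offsets


def unpack_ragged(tokens, offsets):
    """Recover the per-request token lists from the packed form."""
    return [tokens[a:b] for a, b in zip(offsets, offsets[1:])]


batch = [[101, 7, 8], [101, 9], [101, 4, 5, 6]]
tokens, offsets = pack_ragged(batch)
```

Under this reading, alloc() would need room for at most max_batch_input_tokens tokens and max_batch_size + 1 offsets, though the actual buffer sizes are not documented here.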
UnifiedEagleLlama3Config
class max.pipelines.architectures.unified_eagle_llama3.UnifiedEagleLlama3Config(*, target: 'Llama3Config', draft: 'Llama3Config', speculative_config: 'SpeculativeConfig', enable_structured_output: 'bool' = False)
Bases: ArchConfigWithKVCache
Parameters:
- target (Llama3Config)
- draft (Llama3Config)
- speculative_config (SpeculativeConfig)
- enable_structured_output (bool)
draft
draft: Llama3Config
enable_structured_output
enable_structured_output: bool = False
When True, the graph accepts a bitmask input for grammar-constrained decoding.
get_kv_params()
get_kv_params()
KV cache parameters to use when running the model.
Return type:
get_max_seq_len()
get_max_seq_len()
Returns the default maximum sequence length for the model.
Subclasses should determine whether this value can be overridden by
setting the --max-length (pipeline_config.model.max_length) flag.
Return type:
initialize()
classmethod initialize(pipeline_config, model_config=None)
Initialize the config from a PipelineConfig.
Parameters:
- pipeline_config (PipelineConfig): The pipeline configuration.
- model_config (MAXModelConfig | None): The model configuration to read from. When None (the default), pipeline_config.model is used. Pass an explicit config (e.g. pipeline_config.draft_model) to initialize the arch config for a different model.
Return type:
speculative_config
speculative_config: SpeculativeConfig
target
target: Llama3Config
UnifiedEagleLlama3Inputs
class max.pipelines.architectures.unified_eagle_llama3.UnifiedEagleLlama3Inputs(tokens, input_row_offsets, return_n_logits, draft_tokens=None, draft_kv_blocks=None, seed=None, temperature=None, top_k=None, max_k=None, top_p=None, min_top_p=None, in_thinking_phase=None, pinned_bitmask=None, wait_payload=None, device_bitmask_scratch=None, *, kv_cache_inputs=None, lora_ids=None, lora_ranks=None, hidden_states=None)
Bases: ModelInputs
Inputs for the unified EAGLE Llama3 model.
Parameters:
- tokens (Buffer)
- input_row_offsets (Buffer)
- return_n_logits (Buffer)
- draft_tokens (Buffer | None)
- draft_kv_blocks (list[Buffer] | None)
- seed (Buffer | None)
- temperature (Buffer | None)
- top_k (Buffer | None)
- max_k (Buffer | None)
- top_p (Buffer | None)
- min_top_p (Buffer | None)
- in_thinking_phase (Buffer | None)
- pinned_bitmask (Buffer | None)
- wait_payload (Buffer | None)
- device_bitmask_scratch (Buffer | None)
- kv_cache_inputs (KVCacheInputs[Buffer, Buffer] | None)
- lora_ids (Buffer | None)
- lora_ranks (Buffer | None)
- hidden_states (Buffer | list[Buffer] | None)
buffers
Returns positional Buffer inputs for model ABI calls.
device_bitmask_scratch
Device scratch buffer that receives the in-graph H2D from
pinned_bitmask; the acceptance sampler reads from it. Only set
when structured output is enabled.
draft_kv_blocks
draft_tokens
in_thinking_phase
Per-batch bool flag set by the pipeline for relaxed acceptance
during thinking. Not consumed by the unified_eagle_llama3 graph today,
but the field is required to satisfy the _UnifiedSpecDecodeInputs protocol
used by OverlapTextGenerationPipeline.
input_row_offsetsβ
input_row_offsets: Buffer
max_k
min_top_p
Per-batch sampling parameters consumed by the stochastic acceptance
sampler. max_k and min_top_p are 0-d CPU scalars; the rest are
[batch_size] tensors on the primary device.
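To illustrate how batch-wide scalars can relate to per-request values, here is a minimal sketch. It assumes max_k is the largest top_k in the batch and min_top_p is the smallest top_p, which this page does not confirm; batch_sampling_scalars is a hypothetical helper that works on plain Python lists rather than Buffer objects.

```python
def batch_sampling_scalars(top_k, top_p):
    """Reduce per-request top_k/top_p lists to batch-wide scalars.

    Assumption: max_k is the largest top_k and min_top_p the smallest
    top_p, so a sampler sized for them covers every row in the batch.
    """
    if len(top_k) != len(top_p):
        raise ValueError("top_k and top_p need one entry per request")
    if not top_k:
        raise ValueError("batch must not be empty")
    return max(top_k), min(top_p)


max_k, min_top_p = batch_sampling_scalars([40, 1, 100], [0.9, 1.0, 0.8])
```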
pinned_bitmask
Pinned host bitmask for constrained decoding.
Shape [batch_size, num_speculative_tokens + 1, vocab_size].
Position i contains the valid-token mask given the FSM state after
consuming draft[0:i-1]; position num_speculative_tokens is for
the bonus token. None when structured output is disabled.
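The per-position layout described above can be sketched for a single request with a toy grammar. This is only an illustration: build_bitmask, allowed, and step are hypothetical names, the masks are plain lists rather than a pinned host buffer, and row i is read here as the mask after the FSM has consumed the first i draft tokens.

```python
def build_bitmask(draft, start_state, allowed, step, vocab_size):
    """Return num_speculative_tokens + 1 rows of 0/1 masks for one request.

    Row i is the valid-token mask after the first i draft tokens have
    been consumed; the final row applies to the bonus token.
    """
    rows = []
    state = start_state
    for i in range(len(draft) + 1):
        valid = allowed(state)
        rows.append([1 if tok in valid else 0 for tok in range(vocab_size)])
        if i < len(draft):
            state = step(state, draft[i])
    return rows


# Toy grammar over 4 tokens: 0 = "(", 1 = ")", 2 = "x", 3 = end of sequence.
# The FSM state is the current nesting depth.
def allowed(depth):
    return {0, 2, 3} if depth == 0 else {0, 1, 2}


def step(depth, token):
    if token == 0:
        return depth + 1
    if token == 1:
        return depth - 1
    return depth


mask = build_bitmask([0, 2, 1], 0, allowed, step, vocab_size=4)
```

Stacking one such block per request would give the [batch_size, num_speculative_tokens + 1, vocab_size] shape.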
return_n_logits
return_n_logits: Buffer
seed
temperature
tokens
tokens: Buffer
top_k
top_p
wait_payload
CPU int64[2] payload = [flag._unsafe_ptr, 1] consumed by
the in-graph mo.wait_host_value_with_dep op. Only set when
structured output is enabled.
UnifiedEagleLlama3Model
class max.pipelines.architectures.unified_eagle_llama3.UnifiedEagleLlama3Model(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN, return_hidden_states=ReturnHiddenStates.NONE)
Bases: PipelineModelWithKVCache[TextContext]
Unified EAGLE Llama3: target + draft in one compiled graph.
Parameters:
- pipeline_config (PipelineConfig)
- session (InferenceSession)
- devices (list[Device])
- kv_cache_config (KVCacheConfig)
- weights (Weights)
- adapter (WeightsAdapter | None)
- return_logits (ReturnLogits)
- return_hidden_states (ReturnHiddenStates)
execute()
execute(model_inputs)
Execute and return all graph outputs for speculative decoding.
Parameters:
- model_inputs (ModelInputs)
Return type:
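For readers new to speculative decoding, the sketch below shows the general idea of verifying draft tokens against a target model using a simple greedy rule. It is not this module's method: the fields documented above indicate that acceptance here is stochastic and runs inside the compiled graph. accept_greedy is a hypothetical helper operating on plain token lists.

```python
def accept_greedy(draft_tokens, target_tokens):
    """Accept the longest draft prefix that the target agrees with.

    target_tokens holds len(draft_tokens) + 1 entries: the target's
    choice at each draft position plus one position past the last draft.
    Returns the accepted draft tokens followed by one target token: a
    correction on the first mismatch, or the bonus token if every draft
    token matched.
    """
    if len(target_tokens) != len(draft_tokens) + 1:
        raise ValueError("expected one more target token than draft tokens")
    n = 0
    while n < len(draft_tokens) and draft_tokens[n] == target_tokens[n]:
        n += 1
    return draft_tokens[:n] + [target_tokens[n]]


partial = accept_greedy([5, 6, 7], [5, 6, 9, 1])
full = accept_greedy([5, 6, 7], [5, 6, 7, 8])
```

Each call yields at least one new token, which is why a step never makes less progress than ordinary decoding under this rule.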
load_model()
load_model(session)
Parameters:
- session (InferenceSession)
Return type:
model
model: Model
model_config_cls
model_config_cls
alias of Llama3Config
prepare_initial_token_inputs()
prepare_initial_token_inputs(replica_batches, kv_cache_inputs=None, return_n_logits=1)
Prepares the initial inputs to be passed to execute().
The inputs and functionality can vary per model. For example, model
inputs could include encoded tensors, unique IDs per tensor when using
a KV cache manager, and kv_cache_inputs (or None if the model does
not use KV cache). This method typically batches encoded tensors,
claims a KV cache slot if needed, and returns the inputs and caches.
Parameters:
- replica_batches (Sequence[Sequence[TextContext]])
- kv_cache_inputs (KVCacheInputs[Buffer, Buffer] | None)
- return_n_logits (int)
Return type: