Python module
max.pipelines.architectures.eagle_llama3
EAGLE speculative decoding draft model for Llama 3.
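EAGLE-style speculative decoding pairs a cheap draft model (this module) with the full target model: the draft proposes several tokens ahead, and the target verifies them, accepting the longest agreeing prefix plus one corrected token. The loop below is a toy sketch of that draft/verify pattern in plain Python; the `draft_next`/`target_next` functions are illustrative stand-ins, not the MAX APIs.

```python
# Toy sketch of the draft/verify loop behind speculative decoding.
# The "models" here are plain functions over integer tokens; names
# like draft_next/target_next are illustrative, not from MAX.

def draft_next(ctx):
    # Cheap draft model: toy rule, predicts the next integer token.
    return (ctx[-1] + 1) % 10

def target_next(ctx):
    # Expensive target model: mostly agrees with the draft, but
    # differs whenever the context length is a multiple of 4.
    return (ctx[-1] + 1) % 10 if len(ctx) % 4 else (ctx[-1] + 2) % 10

def speculative_step(ctx, k=4):
    """Draft k tokens, then verify them with the target model.

    Returns the tokens actually accepted this step."""
    draft, work = [], list(ctx)
    for _ in range(k):
        t = draft_next(work)
        draft.append(t)
        work.append(t)
    accepted, work = [], list(ctx)
    for t in draft:
        want = target_next(work)
        if want == t:
            accepted.append(t)
            work.append(t)
        else:
            # First disagreement: take the target's token and stop.
            accepted.append(want)
            break
    return accepted

tokens = [0]
for _ in range(3):
    tokens += speculative_step(tokens)
```

Each step emits up to k+1 tokens for a single target-model verification pass, which is where the speedup comes from.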
EagleLlama3Model
class max.pipelines.architectures.eagle_llama3.EagleLlama3Model(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN, return_hidden_states=ReturnHiddenStates.LAST)
Bases: LlamaModelBase
EAGLE Llama3 draft model pipeline implementation.
Parameters:
- pipeline_config (PipelineConfig) – The configuration for this pipeline.
- session (InferenceSession) – The container for the runtime for this model.
- devices (list[Device])
- kv_cache_config (KVCacheConfig)
- weights (Weights)
- adapter (WeightsAdapter | None)
- return_logits (ReturnLogits)
- return_hidden_states (ReturnHiddenStates)
execute()
execute(model_inputs)
Executes the graph with the given inputs. This is an abstract method that must be implemented by concrete PipelineModels to define their specific execution logic.

Parameters:

model_inputs (ModelInputs) – The model inputs to execute, containing tensors and any other required data for model execution.

Returns:

ModelOutputs containing the pipeline's output tensors.

Return type:

ModelOutputs
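Since `execute()` is abstract on the base pipeline model, concrete models like `EagleLlama3Model` supply the body. The sketch below shows that contract with plain-Python stand-ins for `ModelInputs`/`ModelOutputs` (the real classes carry device tensors and live in the MAX package; the `ToyDraftModel` name is hypothetical).

```python
# Minimal sketch of the execute() contract, using plain-Python
# stand-ins for the MAX ModelInputs/ModelOutputs classes.
from abc import ABC, abstractmethod
from dataclasses import dataclass

@dataclass
class ModelInputs:           # stand-in: real class carries device tensors
    tokens: list[int]

@dataclass
class ModelOutputs:          # stand-in: real class carries output tensors
    logits: list[float]

class PipelineModel(ABC):
    @abstractmethod
    def execute(self, model_inputs: ModelInputs) -> ModelOutputs:
        """Concrete PipelineModels define their execution logic here."""

class ToyDraftModel(PipelineModel):
    def execute(self, model_inputs: ModelInputs) -> ModelOutputs:
        # A real model would run the compiled graph on its devices;
        # here we just fake per-token logits.
        return ModelOutputs(logits=[float(t) for t in model_inputs.tokens])

out = ToyDraftModel().execute(ModelInputs(tokens=[1, 2, 3]))
```

Instantiating the abstract base directly raises `TypeError`, which is how the "must be implemented" requirement is enforced at runtime.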
norm_method
norm_method: Literal['rms_norm'] | Literal['layer_norm'] = 'rms_norm'
Normalization layer.
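For this architecture `norm_method` defaults to `'rms_norm'`. The difference from `'layer_norm'` can be sketched in pure Python on flat lists (illustrative only; the real layers operate on device tensors):

```python
import math

def rms_norm(x, weight, eps=1e-6):
    # RMSNorm: scale by the root-mean-square; no mean subtraction, no bias.
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v / rms for v, w in zip(x, weight)]

def layer_norm(x, weight, bias, eps=1e-6):
    # LayerNorm: subtract the mean, divide by the standard deviation.
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    std = math.sqrt(var + eps)
    return [w * (v - mean) / std + b for v, w, b in zip(x, weight, bias)]

y = rms_norm([1.0, 2.0, 3.0], [1.0, 1.0, 1.0])
```

RMSNorm drops the mean-centering and bias of LayerNorm, which saves work per layer while normalizing activation scale the same way.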