Python module
max.pipelines.architectures.unified_dflash_llama3
DFlash speculative decoding for Llama3 with unified graph compilation.
DflashDraftHFConfig
class max.pipelines.architectures.unified_dflash_llama3.DflashDraftHFConfig(mask_token_id: 'int', target_layer_ids: 'list[int]', block_size: 'int | None' = None, num_target_layers: 'int | None' = None)
Bases: object
Parameters:
- mask_token_id (int)
- target_layer_ids (list[int])
- block_size (int | None)
- num_target_layers (int | None)
block_size
block_size: int | None = None
mask_token_id
mask_token_id: int
num_target_layers
num_target_layers: int | None = None
target_layer_ids
target_layer_ids: list[int]
PersistentInputBuffers
class max.pipelines.architectures.unified_dflash_llama3.PersistentInputBuffers(tokens: 'Buffer', input_row_offsets: 'Buffer')
Bases: object
alloc()
classmethod alloc(max_batch_size, max_batch_input_tokens, device)
Parameters:
- max_batch_size
- max_batch_input_tokens
- device
Return type:
PersistentInputBuffers
input_row_offsets
input_row_offsets: Buffer
tokens
tokens: Buffer
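The idea behind these persistent buffers can be sketched in plain NumPy (a hypothetical stand-in for alloc(), not the real Buffer-backed implementation): allocate the maximum-size tokens and row-offsets arrays once, then fill slices of them on every step instead of reallocating.

```python
import numpy as np

def alloc_persistent_buffers(max_batch_size: int, max_batch_input_tokens: int):
    """Allocate maximum-size input buffers once, up front."""
    tokens = np.zeros(max_batch_input_tokens, dtype=np.int64)
    # One start offset per row plus a trailing end offset.
    input_row_offsets = np.zeros(max_batch_size + 1, dtype=np.uint32)
    return tokens, input_row_offsets

def fill_step(tokens, input_row_offsets, batch):
    """Copy this step's ragged batch into slices of the persistent buffers."""
    pos = 0
    for i, seq in enumerate(batch):
        input_row_offsets[i] = pos
        tokens[pos:pos + len(seq)] = seq
        pos += len(seq)
    input_row_offsets[len(batch)] = pos
    # Views over the filled prefix; no new allocation per step.
    return tokens[:pos], input_row_offsets[:len(batch) + 1]

tokens, offsets = alloc_persistent_buffers(max_batch_size=8, max_batch_input_tokens=64)
step_tokens, step_offsets = fill_step(tokens, offsets, [[1, 2, 3], [4, 5]])
```

Reusing fixed-size buffers this way keeps pointer identity stable across steps, which is what lets a compiled graph be executed repeatedly without re-binding inputs.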
UnifiedDflashLlama3Config
class max.pipelines.architectures.unified_dflash_llama3.UnifiedDflashLlama3Config(*, target: 'Llama3Config', draft: 'Llama3Config', speculative_config: 'SpeculativeConfig', target_layer_ids: 'list[int]' = <factory>, mask_token_id: 'int' = 0, block_size: 'int' = 0)
Bases: ArchConfigWithKVCache
Parameters:
- target (Llama3Config)
- draft (Llama3Config)
- speculative_config (SpeculativeConfig)
- target_layer_ids (list[int])
- mask_token_id (int)
- block_size (int)
block_size
block_size: int = 0
draft
draft: Llama3Config
get_kv_params()
get_kv_params()
KV cache parameters to use when running the model.
Return type:
KVCacheParams
get_max_seq_len()
get_max_seq_len()
Returns the default maximum sequence length for the model.
Subclasses should determine whether this value can be overridden by setting the --max-length (pipeline_config.model.max_length) flag.
Return type:
int
initialize()
classmethod initialize(pipeline_config, model_config=None)
Initialize the config from a PipelineConfig.
Parameters:
- pipeline_config (PipelineConfig) – The pipeline configuration.
- model_config (MAXModelConfig | None) – The model configuration to read from. When None (the default), pipeline_config.model is used. Pass an explicit config (e.g. pipeline_config.draft_model) to initialize the arch config for a different model.
Return type:
UnifiedDflashLlama3Config
mask_token_id
mask_token_id: int = 0
resolve_block_size()
resolve_block_size(*, default=None)
speculative_config
speculative_config: SpeculativeConfig
target
target: Llama3Config
target_layer_ids
target_layer_ids: list[int]
validate_dflash_fields()
validate_dflash_fields()
Strict validation run from UnifiedDflashLlama3Model.load_model once the DFlash-specific fields have been populated from the draft HF config; __post_init__ accepts the empty-placeholder config produced by initialize(), so we can't enforce these there.
Return type:
None
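The two-phase pattern described above (a permissive __post_init__, then strict validation once fields are populated) can be illustrated with a small, hypothetical dataclass that is not the real config class:

```python
from dataclasses import dataclass, field

@dataclass
class TwoPhaseConfig:
    # Placeholder defaults so the config can be constructed before the
    # draft HF config has been read; __post_init__ must accept these.
    target_layer_ids: list = field(default_factory=list)
    block_size: int = 0

    def validate_fields(self):
        # Strict checks are deferred to load time, when the fields
        # have been filled in from the draft HF config.
        if not self.target_layer_ids:
            raise ValueError("target_layer_ids must be non-empty")
        if self.block_size <= 0:
            raise ValueError("block_size must be positive")

placeholder = TwoPhaseConfig()              # fine: no validation yet
populated = TwoPhaseConfig(target_layer_ids=[0, 1], block_size=4)
populated.validate_fields()                 # passes once populated
```

Deferring the checks avoids rejecting the intentionally empty placeholder that initialize() produces before model loading.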
UnifiedDflashLlama3Inputs
class max.pipelines.architectures.unified_dflash_llama3.UnifiedDflashLlama3Inputs(tokens, input_row_offsets, return_n_logits, draft_tokens=None, draft_kv_blocks=None, seed=None, temperature=None, top_k=None, max_k=None, top_p=None, min_top_p=None, in_thinking_phase=None, token_bitmasks=None, *, kv_cache_inputs=None, lora_ids=None, lora_ranks=None, hidden_states=None)
Bases: ModelInputs
Inputs for the unified DFlash Llama3 graph.
Carries the buffers consumed by a single execute of the unified graph: the merged tokens / ragged offsets, the draft tokens to verify (None on prefill), the persistent draft KV pool, and the sampling parameters used by the in-graph acceptance sampler.
Parameters:
- tokens (Buffer)
- input_row_offsets (Buffer)
- return_n_logits (Buffer)
- draft_tokens (Buffer | None)
- draft_kv_blocks (list[Buffer] | None)
- seed (Buffer | None)
- temperature (Buffer | None)
- top_k (Buffer | None)
- max_k (Buffer | None)
- top_p (Buffer | None)
- min_top_p (Buffer | None)
- in_thinking_phase (Buffer | None)
- token_bitmasks (Buffer | None)
- kv_cache_inputs (KVCacheInputs[Buffer, Buffer] | None)
- lora_ids (Buffer | None)
- lora_ranks (Buffer | None)
- hidden_states (Buffer | list[Buffer] | None)
buffers
Returns positional Buffer inputs for model ABI calls.
draft_kv_blocks
draft_tokens
in_thinking_phase
input_row_offsets
input_row_offsets: Buffer
max_k
min_top_p
return_n_logits
return_n_logits: Buffer
seed
temperature
token_bitmasks
tokens
tokens: Buffer
top_k
top_p
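The role of draft_tokens in verification can be illustrated with a framework-free sketch of greedy draft verification (a hypothetical simplification, not the in-graph acceptance sampler itself): the target scores all draft positions in one pass, and the longest prefix on which the target's greedy choice agrees with the draft is accepted, plus one corrected token at the first disagreement.

```python
import numpy as np

def verify_draft(draft_tokens, target_logits):
    """Greedy speculative verification.

    draft_tokens:  (k,) draft-proposed token ids.
    target_logits: (k, vocab) target logits at each draft position.
    Returns the accepted tokens: the agreeing prefix of the draft plus
    the target's correction at the first disagreement.
    """
    target_choice = target_logits.argmax(axis=-1)  # target's greedy pick per position
    accepted = []
    for i, tok in enumerate(draft_tokens):
        if target_choice[i] == tok:
            accepted.append(int(tok))               # draft agreed: keep it
        else:
            accepted.append(int(target_choice[i]))  # disagreement: correct and stop
            break
    return accepted

vocab = 10
logits = np.full((3, vocab), -1.0)
logits[0, 4] = 1.0   # target wants token 4 at position 0
logits[1, 7] = 1.0   # target wants token 7 at position 1
logits[2, 2] = 1.0   # target wants token 2 at position 2
accepted = verify_draft(np.array([4, 9, 2]), logits)  # draft disagrees at position 1
```

Running both models inside one compiled graph lets this accept/reject decision happen without a host round-trip between draft and target.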
UnifiedDflashLlama3Model
class max.pipelines.architectures.unified_dflash_llama3.UnifiedDflashLlama3Model(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN, return_hidden_states=ReturnHiddenStates.NONE)
Bases: PipelineModelWithKVCache[TextContext]
Unified DFlash Llama3: target + draft in one compiled graph.
Parameters:
- pipeline_config (PipelineConfig)
- session (InferenceSession)
- devices (list[Device])
- kv_cache_config (KVCacheConfig)
- weights (Weights)
- adapter (WeightsAdapter | None)
- return_logits (ReturnLogits)
- return_hidden_states (ReturnHiddenStates)
calculate_max_seq_len()
classmethod calculate_max_seq_len(pipeline_config, huggingface_config)
Calculates the optimal max sequence length for the model.
Models are expected to implement this method. The following example shows how to implement it for a Mistral model:

class MistralModel(PipelineModel):
    @classmethod
    def calculate_max_seq_len(cls, pipeline_config, huggingface_config) -> int:
        try:
            return upper_bounded_default(
                upper_bound=huggingface_config.max_seq_len,
                default=pipeline_config.model.max_length,
            )
        except ValueError as e:
            raise ValueError(
                "Unable to infer max_length for Mistral, the provided "
                f"max_length ({pipeline_config.model.max_length}) exceeds the "
                f"model's max_seq_len ({huggingface_config.max_seq_len})."
            ) from e

Parameters:
- pipeline_config (PipelineConfig) – Configuration for the pipeline.
- huggingface_config (AutoConfig) – Hugging Face model configuration.
Returns:
The maximum sequence length to use.
Return type:
int
execute()
execute(model_inputs)
Executes the graph with the given inputs.
Parameters:
model_inputs (ModelInputs) – The model inputs to execute, containing tensors and any other required data for model execution.
Returns:
ModelOutputs containing the pipeline's output tensors.
Return type:
ModelOutputs
This is an abstract method that must be implemented by concrete PipelineModels to define their specific execution logic.
get_kv_params()
classmethod get_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)
Returns the KV cache params for the pipeline model.
Delegates to model_config_cls.construct_kv_params(...).
Subclasses with custom KV behavior should override this method.
Parameters:
- huggingface_config (AutoConfig)
- pipeline_config (PipelineConfig)
- devices (list[DeviceRef])
- kv_cache_config (KVCacheConfig)
- cache_dtype (DType)
Return type:
KVCacheParams
load_model()
load_model(session)
Parameters:
session (InferenceSession)
Return type:
Model
model
model: Model
prepare_initial_token_inputs()
prepare_initial_token_inputs(replica_batches, kv_cache_inputs=None, return_n_logits=1)
Prepares the initial inputs to be passed to execute().
The inputs and functionality can vary per model. For example, model inputs could include encoded tensors, unique IDs per tensor when using a KV cache manager, and kv_cache_inputs (or None if the model does not use KV cache). This method typically batches encoded tensors, claims a KV cache slot if needed, and returns the inputs and caches.
Parameters:
- replica_batches (Sequence[Sequence[TextContext]])
- kv_cache_inputs (KVCacheInputs[Buffer, Buffer] | None)
- return_n_logits (int)
Return type:
ModelInputs
prepare_next_token_inputs()
prepare_next_token_inputs(next_tokens, prev_model_inputs)
Prepares the secondary inputs to be passed to execute().
While prepare_initial_token_inputs manages the initial inputs, this method updates them for each step of a multi-step execution pattern.
Parameters:
- next_tokens (Buffer)
- prev_model_inputs (ModelInputs)
Return type:
ModelInputs
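The division of labor between the two prepare methods can be sketched abstractly (hypothetical helper names; the real methods operate on Buffers and KV-cache state): initial preparation flattens the whole ragged batch once for prefill, while next-token preparation only swaps in the newly sampled tokens, one per sequence, for each subsequent decode step.

```python
def prepare_initial_inputs(batch):
    """Flatten a ragged batch into tokens + row offsets (prefill inputs)."""
    tokens, offsets = [], [0]
    for seq in batch:
        tokens.extend(seq)
        offsets.append(len(tokens))
    return {"tokens": tokens, "input_row_offsets": offsets}

def prepare_next_inputs(next_tokens, prev_inputs):
    """Each decode step feeds exactly one new token per sequence."""
    n = len(next_tokens)
    return {"tokens": list(next_tokens),
            "input_row_offsets": list(range(n + 1))}  # one token per row

inputs = prepare_initial_inputs([[1, 2, 3], [4, 5]])
step2 = prepare_next_inputs([7, 8], inputs)
```

Keeping the per-step update cheap is the point of the split: only the sampled tokens change between steps, so nothing else needs to be rebuilt.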