Python module

max.pipelines.architectures.unified_dflash_llama3

DFlash speculative decoding for Llama3 with unified graph compilation.

DflashDraftHFConfig

class max.pipelines.architectures.unified_dflash_llama3.DflashDraftHFConfig(mask_token_id: 'int', target_layer_ids: 'list[int]', block_size: 'int | None' = None, num_target_layers: 'int | None' = None)

Bases: object

Parameters:

  • mask_token_id (int)
  • target_layer_ids (list[int])
  • block_size (int | None)
  • num_target_layers (int | None)
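The reference does not say how mask_token_id and block_size are consumed. Assuming the DFlash drafter predicts a whole block of tokens in one pass, with the last verified token in the first slot and mask_token_id filling the remaining slots, the hypothetical sketch below builds such a block from a stand-in config. DraftBlockConfig and build_draft_block are illustrative names, not part of max, and the mask id 7 is arbitrary.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class DraftBlockConfig:
    """Hypothetical stand-in mirroring the DflashDraftHFConfig fields."""

    mask_token_id: int
    target_layer_ids: list[int]
    block_size: int | None = None
    num_target_layers: int | None = None


def build_draft_block(cfg: DraftBlockConfig, last_token: int) -> list[int]:
    """Return one draft block: the anchor token followed by mask placeholders.

    Assumption: position 0 holds the last verified token, and the remaining
    block_size - 1 positions hold mask_token_id for the drafter to fill in.
    """
    if cfg.block_size is None or cfg.block_size < 2:
        raise ValueError("block_size must be set and at least 2")
    return [last_token] + [cfg.mask_token_id] * (cfg.block_size - 1)


cfg = DraftBlockConfig(mask_token_id=7, target_layer_ids=[1, 16, 30], block_size=4)
print(build_draft_block(cfg, last_token=42))  # [42, 7, 7, 7]
```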

block_size

block_size: int | None = None

mask_token_id

mask_token_id: int

num_target_layers

num_target_layers: int | None = None

target_layer_ids

target_layer_ids: list[int]
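target_layer_ids presumably names the target-model layers whose hidden states condition the drafter. The sketch below assumes those per-layer hidden vectors are simply concatenated for each token position, as in some multi-layer feature-fusion drafters; DFlash may project or combine them differently, and gather_target_features is a hypothetical helper operating on toy Python lists.

```python
from __future__ import annotations


def gather_target_features(
    layer_hidden_states: list[list[float]], target_layer_ids: list[int]
) -> list[float]:
    """Concatenate the hidden-state vectors of the selected target layers.

    Hypothetical layout: layer_hidden_states[i] is the hidden vector that
    target layer i produced for a single token position.
    """
    num_layers = len(layer_hidden_states)
    features: list[float] = []
    for layer_id in target_layer_ids:
        if not 0 <= layer_id < num_layers:
            raise IndexError(f"layer {layer_id} out of range for {num_layers} layers")
        features.extend(layer_hidden_states[layer_id])
    return features


# Four toy "layers", each with a 2-dim hidden state.
hidden = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]
print(gather_target_features(hidden, [1, 3]))  # [1.0, 1.1, 3.0, 3.1]
```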

PersistentInputBuffers

class max.pipelines.architectures.unified_dflash_llama3.PersistentInputBuffers(tokens, input_row_offsets)

Bases: object

Pinned-host buffers reused across unified spec-decode batch steps.

Parameters:

  • tokens (Buffer)
  • input_row_offsets (Buffer)

alloc()

classmethod alloc(max_batch_size, max_batch_input_tokens, device)

Allocates persistent token and row-offset buffers for spec-decode batching.

Parameters:

  • max_batch_size (int)
  • max_batch_input_tokens (int)
  • device (Device)

Return type:

PersistentInputBuffers
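A rough model of why persistent buffers help: they are sized once for the worst case (here mirroring the max_batch_size and max_batch_input_tokens arguments of alloc), and each batch step overwrites only the prefix it needs. The sketch assumes input_row_offsets holds max_batch_size + 1 cumulative token counts starting at 0, so row i spans tokens[offsets[i]:offsets[i + 1]]. PersistentBuffersSketch is hypothetical and uses plain host arrays, so it does not model pinned memory, the device argument, or host-to-device copies.

```python
from __future__ import annotations

from array import array


class PersistentBuffersSketch:
    """Hypothetical host-side stand-in for PersistentInputBuffers.

    Both buffers are allocated once for the worst case and reused on every
    step; each call to pack() overwrites only the prefix it needs.
    """

    def __init__(self, max_batch_size: int, max_batch_input_tokens: int) -> None:
        self.tokens = array("q", [0] * max_batch_input_tokens)
        # One more offset than rows: row i spans tokens[offsets[i]:offsets[i + 1]].
        self.input_row_offsets = array("I", [0] * (max_batch_size + 1))

    def pack(self, batch: list[list[int]]) -> tuple[int, int]:
        """Write a ragged batch into the buffers; return (rows, total tokens)."""
        total = sum(len(seq) for seq in batch)
        if len(batch) + 1 > len(self.input_row_offsets):
            raise ValueError("batch exceeds max_batch_size")
        if total > len(self.tokens):
            raise ValueError("batch exceeds max_batch_input_tokens")
        offset = 0
        self.input_row_offsets[0] = 0
        for row, seq in enumerate(batch):
            self.tokens[offset : offset + len(seq)] = array("q", seq)
            offset += len(seq)
            self.input_row_offsets[row + 1] = offset
        return len(batch), total


bufs = PersistentBuffersSketch(max_batch_size=4, max_batch_input_tokens=16)
n_rows, n_tokens = bufs.pack([[5, 6, 7], [8], [9, 10]])
print(list(bufs.input_row_offsets[: n_rows + 1]))  # [0, 3, 4, 6]
print(list(bufs.tokens[:n_tokens]))  # [5, 6, 7, 8, 9, 10]
```

Reusing the same object for the next step, for example bufs.pack([[1, 2]]), rewrites only the first two tokens and offsets; the buffer lengths never change.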

input_row_offsets

input_row_offsets: Buffer

tokens

tokens: Buffer

UnifiedDflashLlama3Config

class max.pipelines.architectures.unified_dflash_llama3.UnifiedDflashLlama3Config(*, target: 'Llama3Config', draft: 'Llama3Config', speculative_config: 'SpeculativeConfig', target_layer_ids: 'list[int]' = <factory>, mask_token_id: 'int' = 0, block_size: 'int' = 0)

Bases: ArchConfigWithKVCache

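Parameters:

  • target (Llama3Config)
  • draft (Llama3Config)
  • speculative_config (SpeculativeConfig)
  • target_layer_ids (list[int])
  • mask_token_id (int)
  • block_size (int)

block_size

block_size: int = 0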

draft

draft: Llama3Config

get_kv_params()

get_kv_params()

KV cache parameters to use when running the model.

Return type:

KVCacheParamInterface

get_max_seq_len()

get_max_seq_len()

Returns the default maximum sequence length for the model.

Subclasses should determine whether this value can be overridden by setting the --max-length (pipeline_config.model.max_length) flag.

Return type:

int

initialize()

classmethod initialize(pipeline_config, model_config=None)

Initialize the config from a PipelineConfig.

Parameters:

  • pipeline_config (PipelineConfig) – The pipeline configuration.
  • model_config (MAXModelConfig | None) – The model configuration to read from. When None (the default), pipeline_config.model is used. Pass an explicit config (e.g. pipeline_config.draft_model) to initialize the arch config for a different model.

Return type:

Self

mask_token_id

mask_token_id: int = 0

resolve_block_size()

resolve_block_size(*, default=None)

Parameters:

  • default (int | None)

Return type:

int
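resolve_block_size has no docstring here. Given that block_size defaults to 0 on this config and is optional on DflashDraftHFConfig, one plausible reading is that 0 marks an unset value and default supplies a fallback. The free function below is a hypothetical sketch of that reading, not the method's actual logic.

```python
from __future__ import annotations


def resolve_block_size(block_size: int, *, default: int | None = None) -> int:
    """Hypothetical rule: a positive block_size wins; 0 means "unset"."""
    if block_size > 0:
        return block_size
    if default is not None and default > 0:
        return default
    raise ValueError("block_size is unset and no positive default was given")


print(resolve_block_size(16))            # 16
print(resolve_block_size(0, default=8))  # 8
```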

speculative_config

speculative_config: SpeculativeConfig

target

target: Llama3Config

target_layer_ids

target_layer_ids: list[int]

validate_dflash_fields()

validate_dflash_fields()

Performs strict validation of the DFlash-specific fields. It is called from UnifiedDflashLlama3Model.load_model once those fields have been populated from the draft HF config. These checks cannot run in __post_init__, because __post_init__ must accept the empty placeholder config produced by initialize().

Return type:

None
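The docstring describes a two-phase pattern: a lenient __post_init__ that tolerates a placeholder config, followed by a strict check once real values are known. The sketch below illustrates that pattern with a hypothetical TwoPhaseConfigSketch; the specific checks (non-empty target_layer_ids, positive block_size, in-range layer ids) are assumptions about what such validation might cover, not the actual rules in validate_dflash_fields.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class TwoPhaseConfigSketch:
    """Hypothetical config that tolerates a placeholder state until loading."""

    target_layer_ids: list[int] = field(default_factory=list)
    mask_token_id: int = 0
    block_size: int = 0

    def __post_init__(self) -> None:
        # Lenient phase: reject only values that are invalid in any state.
        if self.block_size < 0:
            raise ValueError("block_size must be non-negative")

    def validate_fields(self, num_target_layers: int) -> None:
        # Strict phase: run after fields are filled from the draft HF config.
        if self.block_size <= 0:
            raise ValueError("block_size was never populated")
        if not self.target_layer_ids:
            raise ValueError("target_layer_ids must not be empty")
        for layer_id in self.target_layer_ids:
            if not 0 <= layer_id < num_target_layers:
                raise ValueError(f"target layer {layer_id} out of range")


placeholder = TwoPhaseConfigSketch()  # accepted, like the config from initialize()
placeholder.target_layer_ids = [1, 15, 30]
placeholder.block_size = 16
placeholder.validate_fields(num_target_layers=32)  # passes once fields are set
```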

UnifiedDflashLlama3Inputs

class max.pipelines.architectures.unified_dflash_llama3.UnifiedDflashLlama3Inputs(tokens, input_row_offsets, return_n_logits, *, kv_cache_inputs=None, lora=None, hidden_states=None, draft_tokens=None, seed=None, temperature=None, top_k=None, max_k=None, top_p=None, min_top_p=None, in_thinking_phase=None, pinned_bitmask=None, wait_payload=None, device_bitmask_scratch=None, structured_output=False)

Bases: UnifiedSpecDecodeInputs

Inputs for the unified DFlash Llama3 graph.

The spec-decode fields and the packing of trailing buffers are inherited from UnifiedSpecDecodeInputs. The tokens, input_row_offsets, and return_n_logits buffers, together with the KV cache inputs, form the prefix of this single-device graph's inputs. The DFlash graph does not bind in_thinking_phase.

buffers

property buffers: tuple[Buffer, ...]

Returns positional Buffer inputs for model ABI calls.
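The class docstring describes positional inputs made of a graph-owned prefix followed by trailing spec-decode buffers. The hypothetical sketch below assumes the prefix order tokens, input_row_offsets, return_n_logits, then the KV cache inputs, and appends only a few of the spec-decode fields, skipping unset ones and never emitting in_thinking_phase. The real field order and the handling of unset fields are defined by UnifiedSpecDecodeInputs and may differ; strings stand in for Buffer objects.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any

Buffer = Any  # placeholder for the real Buffer type in this sketch


@dataclass
class DflashInputsSketch:
    """Hypothetical stand-in showing one possible positional buffer layout."""

    tokens: Buffer
    input_row_offsets: Buffer
    return_n_logits: Buffer
    kv_cache_inputs: list[Buffer] = field(default_factory=list)
    # A few trailing spec-decode fields; None means "not provided".
    draft_tokens: Buffer | None = None
    seed: Buffer | None = None
    temperature: Buffer | None = None
    # Accepted for interface compatibility but never emitted below,
    # mirroring "the DFlash graph does not bind in_thinking_phase".
    in_thinking_phase: Buffer | None = None

    @property
    def buffers(self) -> tuple[Buffer, ...]:
        # Graph-owned prefix: token data, row offsets, logit count, KV cache.
        prefix = (self.tokens, self.input_row_offsets, self.return_n_logits)
        prefix += tuple(self.kv_cache_inputs)
        # Trailing spec-decode buffers in a fixed order; unset ones are skipped.
        trailing = tuple(
            b for b in (self.draft_tokens, self.seed, self.temperature) if b is not None
        )
        return prefix + trailing


inputs = DflashInputsSketch(
    "tok", "offs", "n_logits",
    kv_cache_inputs=["kv0", "kv1"],
    draft_tokens="draft", temperature="temp", in_thinking_phase="think",
)
print(inputs.buffers)  # ('tok', 'offs', 'n_logits', 'kv0', 'kv1', 'draft', 'temp')
```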

input_row_offsets

input_row_offsets: Buffer

return_n_logits

return_n_logits: Buffer

tokens

tokens: Buffer

UnifiedDflashLlama3Model

class max.pipelines.architectures.unified_dflash_llama3.UnifiedDflashLlama3Model(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN, return_hidden_states=ReturnHiddenStates.NONE, max_batch_size=1)

Bases: _UnifiedSpecDecodeModelMixin, PipelineModelWithKVCache[TextContext]

Unified DFlash Llama3: target + draft in one compiled graph.
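Because target and draft share one compiled graph, a single step can both propose tokens and verify them. The verification half is sketched below using the standard greedy acceptance rule from speculative decoding: draft tokens are kept while they match the target's argmax, the first mismatch is replaced by the target's token, and a fully accepted block earns one bonus token. This is a generic illustration, not this model's code; the sampling inputs on UnifiedDflashLlama3Inputs (temperature, top_k, top_p) suggest the real graph also supports sampled acceptance, which the hypothetical greedy_accept helper does not cover.

```python
from __future__ import annotations


def greedy_accept(draft_tokens: list[int], target_argmax: list[int]) -> list[int]:
    """Return the tokens committed after one greedy verification step.

    target_argmax[i] is the target's greedy choice at draft position i; it has
    one extra entry for the position after the last draft token.
    """
    if len(target_argmax) != len(draft_tokens) + 1:
        raise ValueError("target_argmax must have len(draft_tokens) + 1 entries")
    accepted: list[int] = []
    for draft, target in zip(draft_tokens, target_argmax):
        if draft != target:
            # First mismatch: commit the target's token instead and stop.
            accepted.append(target)
            return accepted
        accepted.append(draft)
    # Every draft token matched, so the target's extra prediction is a bonus.
    accepted.append(target_argmax[-1])
    return accepted


print(greedy_accept([5, 9, 2], [5, 9, 7, 3]))  # [5, 9, 7]
print(greedy_accept([5, 9, 2], [5, 9, 2, 3]))  # [5, 9, 2, 3]
```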

batch_processor_cls

batch_processor_cls

alias of UnifiedDflashLlama3BatchProcessor

get_kv_params()

classmethod get_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)

Returns the KV cache params for the pipeline model.

Delegates to model_config_cls.construct_kv_params(...). Subclasses with custom KV behavior should override this method.

Return type:

KVCacheParams
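The delegation described above can be pictured with stand-in classes: get_kv_params looks up the model_config_cls class attribute and forwards to its construct_kv_params hook, so a subclass can change KV behavior either by pointing model_config_cls at a different config class or by overriding get_kv_params itself. Every class below is hypothetical, and the argument list is reduced to a Llama-style HF config dict; the real method also takes pipeline_config, devices, kv_cache_config, and cache_dtype.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class KVParamsSketch:
    """Toy stand-in for a KV cache parameter object."""

    n_kv_heads: int
    head_dim: int
    num_layers: int


class LlamaLikeConfigSketch:
    """Hypothetical config class exposing a construct_kv_params hook."""

    @classmethod
    def construct_kv_params(cls, hf_config: dict) -> KVParamsSketch:
        return KVParamsSketch(
            n_kv_heads=hf_config["num_key_value_heads"],
            head_dim=hf_config["hidden_size"] // hf_config["num_attention_heads"],
            num_layers=hf_config["num_hidden_layers"],
        )


class ModelSketch:
    # Class attribute naming the config type, like model_config_cls above.
    model_config_cls = LlamaLikeConfigSketch

    @classmethod
    def get_kv_params(cls, hf_config: dict) -> KVParamsSketch:
        # Default behavior: delegate to the declared config class.
        return cls.model_config_cls.construct_kv_params(hf_config)


toy_hf_config = {
    "hidden_size": 64,
    "num_attention_heads": 4,
    "num_key_value_heads": 2,
    "num_hidden_layers": 3,
}
print(ModelSketch.get_kv_params(toy_hf_config))
# KVParamsSketch(n_kv_heads=2, head_dim=16, num_layers=3)
```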

load_model()

load_model(session)

Parameters:

  • session (InferenceSession)

Return type:

Model

model

model: Model

model_config_cls

model_config_cls

alias of UnifiedDflashLlama3Config