Python module
max.pipelines.architectures.dflash_llama3
DFlash draft model for Llama3-family targets.
The draft is a Qwen3-style transformer (per-head Q/K RMSNorm, non-causal
attention). It fuses concatenated target hidden states into its KV cache
via AttentionWithRope.materialize_kv_from_hidden(), then on each iteration
runs a single non-causal block forward pass over
[verified_id, MASK, MASK, …].
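The block layout and the non-causal attention pattern described above can be illustrated with a small pure-Python sketch. It is not taken from the MAX source: `build_draft_block`, `attention_mask`, the `mask_id` value, and the block size are all hypothetical names and values chosen for illustration, and the treatment of cached context keys as always visible is an assumption.

```python
# Illustrative only: these helpers are hypothetical and not part of the MAX API.


def build_draft_block(verified_id: int, mask_id: int, block_size: int) -> list[int]:
    """Return [verified_id, MASK, MASK, ...] with block_size entries in total."""
    if block_size < 1:
        raise ValueError("block_size must be at least 1")
    return [verified_id] + [mask_id] * (block_size - 1)


def attention_mask(num_context: int, block_size: int, causal: bool) -> list[list[bool]]:
    """mask[i][j] is True when block position i may attend to key j.

    Keys 0..num_context-1 stand for cached context entries (assumed always
    visible); the remaining keys are the block positions themselves.
    """
    total = num_context + block_size
    return [
        [
            j < num_context or (not causal) or (j - num_context) <= i
            for j in range(total)
        ]
        for i in range(block_size)
    ]


block = build_draft_block(verified_id=1234, mask_id=0, block_size=4)
non_causal = attention_mask(num_context=2, block_size=4, causal=False)
causal = attention_mask(num_context=2, block_size=4, causal=True)
```

In the non-causal case every block position can attend to every other block position, including later MASK slots, which is what lets a single forward pass fill the whole block. The causal mask is shown only for contrast.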
DFlashLlama3
class max.pipelines.architectures.dflash_llama3.DFlashLlama3(config, *, num_context_features)
Bases: Module
DFlash draft transformer for a Llama3 target.
Parameters:
- config (Llama3Config)
- num_context_features (int)
forward_block()
forward_block(input_embeds, kv_collection, input_row_offsets)
Parameters:
- input_embeds (TensorValue)
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue])
- input_row_offsets (TensorValue)
Return type:
materialize_kv()
materialize_kv(ctx_hidden, input_row_offsets, kv_collection)
Parameters:
- ctx_hidden (TensorValue)
- input_row_offsets (TensorValue)
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue])
Return type:
None
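Both forward_block() and materialize_kv() take an input_row_offsets tensor. This reference does not describe its contents. The sketch below assumes the common ragged-batch convention, in which sequences of different lengths are packed into one flat tensor and the offsets are prefix sums of the per-sequence lengths. The helper names `row_offsets` and `split_rows` are hypothetical and use plain Python lists rather than TensorValue.

```python
from itertools import accumulate

# Illustrative only: assumes input_row_offsets holds prefix sums of the
# per-sequence lengths in a ragged (packed) batch. Not part of the MAX API.


def row_offsets(seq_lens: list[int]) -> list[int]:
    """Return offsets such that sequence i occupies [offsets[i], offsets[i + 1])."""
    return [0, *accumulate(seq_lens)]


def split_rows(flat: list, offsets: list[int]) -> list[list]:
    """Recover the per-sequence rows from a packed list and its offsets."""
    return [flat[start:end] for start, end in zip(offsets, offsets[1:])]


seq_lens = [3, 1, 2]
offsets = row_offsets(seq_lens)
flat_tokens = ["a0", "a1", "a2", "b0", "c0", "c1"]
rows = split_rows(flat_tokens, offsets)
```

Under this assumption the offsets vector has one more entry than there are sequences, and its last entry equals the total number of packed rows.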
project_target_hidden()
project_target_hidden(target_hs_concat)
Parameters:
- target_hs_concat (TensorValue)
Return type:
DFlashLlama3Model
class max.pipelines.architectures.dflash_llama3.DFlashLlama3Model(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN, return_hidden_states=ReturnHiddenStates.NONE)
Bases: LlamaModelBase
Placeholder pipeline model for the DFlash draft architecture.
See the module description above. execute() raises because the draft
model is only ever run via the unified pipeline.
Parameters:
- pipeline_config (PipelineConfig) – The configuration for this pipeline.
- session (InferenceSession) – The container for the runtime for this model.
- devices (list[Device])
- kv_cache_config (KVCacheConfig)
- weights (Weights)
- adapter (WeightsAdapter | None)
- return_logits (ReturnLogits)
- return_hidden_states (ReturnHiddenStates)
execute()
execute(model_inputs)
Executes the graph with the given inputs.
This is an abstract method that must be implemented by concrete PipelineModels to define their specific execution logic.
Parameters:
- model_inputs (ModelInputs) – The model inputs to execute, containing tensors and any other required data for model execution.
Returns:
ModelOutputs containing the pipeline’s output tensors.
Return type:
ModelOutputs