Python module
max.pipelines.architectures.unified_mtp_gemma4
Gemma4 with an MTP draft model for speculative decoding, using unified graph compilation.
UnifiedMTPGemma4Inputs
class max.pipelines.architectures.unified_mtp_gemma4.UnifiedMTPGemma4Inputs(tokens, input_row_offsets, host_input_row_offsets, return_n_logits, data_parallel_splits, signal_buffers, batch_context_lengths, draft_tokens=None, draft_kv_blocks=None, seed=None, temperature=None, top_k=None, max_k=None, top_p=None, min_top_p=None, in_thinking_phase=None, *, kv_cache_inputs=None, lora_ids=None, lora_ranks=None, hidden_states=None)
Bases: ModelInputs
Inputs for the UnifiedMTPGemma4 model.
Parameters:
- tokens (Buffer)
- input_row_offsets (Buffer)
- host_input_row_offsets (Buffer)
- return_n_logits (Buffer)
- data_parallel_splits (Buffer)
- signal_buffers (list[Buffer])
- batch_context_lengths (list[Buffer])
- draft_tokens (Buffer | None)
- draft_kv_blocks (list[Buffer] | None)
- seed (Buffer | None)
- temperature (Buffer | None)
- top_k (Buffer | None)
- max_k (Buffer | None)
- top_p (Buffer | None)
- min_top_p (Buffer | None)
- in_thinking_phase (Buffer | None)
- kv_cache_inputs (KVCacheInputs[Buffer, Buffer] | None)
- lora_ids (Buffer | None)
- lora_ranks (Buffer | None)
- hidden_states (Buffer | list[Buffer] | None)
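The optional per-row fields temperature, top_k, top_p, and seed parameterize token sampling. As a rough illustration only, assuming these fields carry their conventional meaning, the sketch below applies them to a single row of logits in plain Python. `sample_token` is a hypothetical helper, not part of MAX, and it does not show how the compiled graph uses max_k or min_top_p.

```python
import math
import random
from typing import Optional


def sample_token(
    logits: list[float],
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
    seed: Optional[int] = None,
) -> int:
    """Hypothetical helper: sample one token id for a single batch row."""
    if temperature <= 0.0:
        # Treat a non-positive temperature as greedy decoding.
        return max(range(len(logits)), key=lambda i: logits[i])

    # Temperature-scaled softmax, subtracting the max for numerical stability.
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Token ids sorted from most to least likely.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

    # Top-k: keep only the k most likely tokens (0 disables the filter).
    if top_k > 0:
        order = order[:top_k]

    # Top-p (nucleus): keep the smallest prefix whose mass reaches top_p.
    kept: list[int] = []
    cumulative = 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    # Draw from the surviving tokens; random.choices renormalizes the weights.
    rng = random.Random(seed)
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]
```

Applying top-k before top-p is one common ordering; other samplers apply the filters differently.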
batch_context_lengths
buffers
Returns positional Buffer inputs for model ABI calls.
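To show what "positional inputs" means in general, the sketch below flattens a toy inputs object into a tuple in field-declaration order, skipping unset optional fields and expanding list-valued fields. `ToyInputs` and its ordering rule are hypothetical stand-ins (strings play the role of Buffer objects); the actual order and contents that UnifiedMTPGemma4Inputs.buffers produces are not documented here.

```python
from dataclasses import dataclass, field, fields
from typing import Optional


@dataclass
class ToyInputs:
    """Hypothetical stand-in for a ModelInputs subclass; strings mimic Buffers."""

    tokens: str
    input_row_offsets: str
    return_n_logits: str
    signal_buffers: list = field(default_factory=list)
    draft_tokens: Optional[str] = None

    @property
    def buffers(self) -> tuple:
        # Assumed rule: declaration order, None skipped, lists flattened in place.
        out = []
        for f in fields(self):
            value = getattr(self, f.name)
            if value is None:
                continue  # optional input not provided
            if isinstance(value, list):
                out.extend(value)
            else:
                out.append(value)
        return tuple(out)
```

A fixed positional order lets the caller pass the tuple straight into a compiled graph whose inputs are defined by position rather than by name.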
data_parallel_splits
data_parallel_splits: Buffer
draft_kv_blocks
draft_tokens
host_input_row_offsets
host_input_row_offsets: Buffer
in_thinking_phase
Per-row boolean flag marking batch rows that are currently inside a
<think>...</think> block; consumed by relaxed acceptance.
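As a hypothetical illustration of what this flag tracks, the sketch below derives one boolean per row by scanning each row's generated token strings for opening and closing tags. It assumes each tag appears as a single token, which real tokenizers may not guarantee. `thinking_flags` is not a MAX function, and the relaxed acceptance logic itself is not shown.

```python
def thinking_flags(
    token_texts_per_row: list[list[str]],
    open_tag: str = "<think>",
    close_tag: str = "</think>",
) -> list[bool]:
    """Hypothetical helper: True for rows inside an unclosed think block."""
    flags = []
    for tokens in token_texts_per_row:
        inside = False
        for tok in tokens:
            # The most recent tag seen decides whether the row is inside a block.
            if tok == open_tag:
                inside = True
            elif tok == close_tag:
                inside = False
        flags.append(inside)
    return flags
```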
input_row_offsets
input_row_offsets: Buffer
max_k
min_top_p
return_n_logits
return_n_logits: Buffer
seed
signal_buffers
temperature
tokens
tokens: Buffer
top_k
top_p
UnifiedMTPGemma4Model
class max.pipelines.architectures.unified_mtp_gemma4.UnifiedMTPGemma4Model(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN, return_hidden_states=ReturnHiddenStates.NONE)
Bases: AlwaysSignalBuffersMixin, PipelineModelWithKVCache[TextContext]
Gemma4 with MTP: runs the merge, target-model, rejection, and shift stages of speculative decoding in a single graph.
Parameters:
- pipeline_config (PipelineConfig)
- session (InferenceSession)
- devices (list[Device])
- kv_cache_config (KVCacheConfig)
- weights (Weights)
- adapter (WeightsAdapter | None)
- return_logits (ReturnLogits)
- return_hidden_states (ReturnHiddenStates)
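For background on the rejection stage, the sketch below implements the standard speculative-sampling acceptance rule in plain Python. A draft token is accepted with probability min(1, p/q). On the first rejection, a replacement token is drawn from the residual max(0, p - q), and if every draft token is accepted, one bonus token is sampled from the target. This is a generic illustration of the technique, not the kernel this graph compiles. The merge and shift stages are not shown, and `speculative_accept` is a hypothetical name.

```python
import random


def speculative_accept(
    draft_tokens: list[int],
    draft_probs: list[list[float]],
    target_probs: list[list[float]],
    rng: random.Random,
) -> list[int]:
    """Standard speculative-sampling acceptance for one sequence.

    draft_tokens: K token ids proposed by the draft model.
    draft_probs:  K draft distributions q_i over the vocabulary.
    target_probs: K + 1 target distributions p_i; the last one is used
                  for the bonus token when every draft token is accepted.
    Returns the tokens emitted in this step.
    """
    emitted: list[int] = []
    for i, tok in enumerate(draft_tokens):
        p, q = target_probs[i], draft_probs[i]
        # Accept with probability min(1, p(tok) / q(tok)).
        if rng.random() < min(1.0, p[tok] / q[tok]):
            emitted.append(tok)
            continue
        # Rejected: resample from the residual distribution max(0, p - q).
        residual = [max(0.0, pv - qv) for pv, qv in zip(p, q)]
        emitted.append(rng.choices(range(len(p)), weights=residual, k=1)[0])
        return emitted
    # All draft tokens accepted: sample one bonus token from the target.
    bonus = target_probs[len(draft_tokens)]
    emitted.append(rng.choices(range(len(bonus)), weights=bonus, k=1)[0])
    return emitted
```

Each step emits between 1 and K + 1 tokens. The resampling step keeps the output distribution identical to sampling from the target alone.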
calculate_max_seq_len()
classmethod calculate_max_seq_len(pipeline_config, huggingface_config)
Calculates the optimal max sequence length for the model.
The default implementation delegates to model_config_cls. Override it when
pipeline-model semantics differ from the config (for example, bounding
max_length where the config is permissive).
Parameters:
- pipeline_config (PipelineConfig): Configuration for the pipeline.
- huggingface_config (AutoConfig): Hugging Face model configuration.
Returns:
The maximum sequence length to use.
Return type:
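As a hypothetical illustration of the override described above, the sketch below bounds a user-supplied max_length by the model's positional limit and rejects requests that exceed it. The toy config classes and the field names max_length and max_position_embeddings are assumptions made for this example. They are not the real PipelineConfig or AutoConfig interfaces, and the rule Gemma4 actually applies may differ.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyPipelineConfig:
    # Hypothetical user-requested cap; None means "use the model limit".
    max_length: Optional[int] = None


@dataclass
class ToyHFConfig:
    # Hypothetical positional limit taken from the model configuration.
    max_position_embeddings: int


def bounded_max_seq_len(
    pipeline_config: ToyPipelineConfig, huggingface_config: ToyHFConfig
) -> int:
    """Hypothetical rule: honor max_length, but never exceed the model limit."""
    limit = huggingface_config.max_position_embeddings
    if pipeline_config.max_length is None:
        return limit
    if pipeline_config.max_length > limit:
        raise ValueError(
            f"max_length={pipeline_config.max_length} exceeds the model limit of {limit}"
        )
    return pipeline_config.max_length
```

Raising an error instead of silently clamping makes a misconfigured pipeline fail at startup, not partway through a long request.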
estimate_activation_memory()
classmethod estimate_activation_memory(pipeline_config, huggingface_config)
Estimates the activation memory required for model execution.
This accounts for temporary memory buffers used during model execution, such as intermediate activations and working buffers.
The default implementation returns 0 for backward compatibility. Models with significant activation memory requirements should override this method to provide accurate estimates.
Parameters:
- pipeline_config (PipelineConfig): Pipeline configuration.
- huggingface_config (AutoConfig): Hugging Face model configuration.
Returns:
Estimated activation memory in bytes.
Return type:
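One way an override might produce a non-zero estimate is sketched below. It uses a hypothetical heuristic: a fixed number of simultaneously live hidden-state tensors sized by the largest token batch, plus a float32 logits buffer. The function name, parameters, and constants are assumptions for illustration. This is not what UnifiedMTPGemma4Model computes.

```python
def estimate_activation_bytes(
    max_batch_tokens: int,
    hidden_size: int,
    vocab_size: int,
    logit_rows: int,
    dtype_bytes: int = 2,
    live_buffers: int = 4,
) -> int:
    """Hypothetical heuristic for peak activation memory, in bytes."""
    # Assume `live_buffers` tensors of shape [max_batch_tokens, hidden_size]
    # are alive at the same time at peak.
    hidden = max_batch_tokens * hidden_size * dtype_bytes * live_buffers
    # Assume logits are materialized in float32 (4 bytes) for `logit_rows` rows.
    logits = logit_rows * vocab_size * 4
    return hidden + logits
```

Large-vocabulary models can make the logits term dominate, so the number of rows that return logits matters as much as the hidden size.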
execute()
execute(model_inputs)
Executes the model and returns all three graph outputs used for speculative decoding.
Parameters:
- model_inputs (ModelInputs)
Return type:
load_model()
load_model(session)
Parameters:
- session (InferenceSession)
Return type:
model
model: Model
model_config_cls
model_config_cls
alias of Gemma4ForConditionalGenerationConfig
prepare_initial_token_inputs()
prepare_initial_token_inputs(replica_batches, kv_cache_inputs=None, return_n_logits=1, draft_tokens=None, draft_kv_cache_buffers=None, **kwargs)
Prepares the initial inputs to be passed to execute().
The inputs and functionality can vary per model. For example, model
inputs could include encoded tensors, unique IDs per tensor when using
a KV cache manager, and kv_cache_inputs (or None if the model does
not use a KV cache). This method typically batches encoded tensors,
claims a KV cache slot if needed, and returns the inputs and caches.
Parameters:
- replica_batches
- kv_cache_inputs
- return_n_logits
- draft_tokens
- draft_kv_cache_buffers
- **kwargs
Return type:
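A common layout for variable-length batches is a flat token array plus a row-offset array of length batch_size + 1, where row i spans tokens[offsets[i]:offsets[i + 1]]. The sketch below builds that layout from per-request token lists. It assumes tokens and input_row_offsets follow this convention. `batch_ragged` is a hypothetical helper, and KV-cache slot claiming is not shown.

```python
from itertools import accumulate


def batch_ragged(token_lists: list[list[int]]) -> tuple[list[int], list[int]]:
    """Hypothetical helper: concatenate requests and compute row offsets."""
    # Flatten all requests into one contiguous token array.
    tokens = [t for seq in token_lists for t in seq]
    # Exclusive prefix sum of lengths: row i is tokens[offsets[i]:offsets[i + 1]].
    row_offsets = [0, *accumulate(len(seq) for seq in token_lists)]
    return tokens, row_offsets
```

Packing rows this way avoids padding every request to the longest one in the batch.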
prepare_next_token_inputs()
prepare_next_token_inputs(next_tokens, prev_model_inputs)
Parameters:
- next_tokens (Buffer)
- prev_model_inputs (ModelInputs)
Return type: