IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Python module

max.pipelines.architectures.unified_mtp_gemma4

Gemma4 with a multi-token prediction (MTP) draft model for speculative decoding, compiled as a single unified graph.

UnifiedMTPGemma4Inputs

class max.pipelines.architectures.unified_mtp_gemma4.UnifiedMTPGemma4Inputs(tokens, input_row_offsets, host_input_row_offsets, return_n_logits, data_parallel_splits, signal_buffers, batch_context_lengths, draft_tokens=None, draft_kv_blocks=None, seed=None, temperature=None, top_k=None, max_k=None, top_p=None, min_top_p=None, in_thinking_phase=None, *, kv_cache_inputs=None, lora_ids=None, lora_ranks=None, hidden_states=None)

Bases: ModelInputs

Inputs for the UnifiedMTPGemma4 model.

batch_context_lengths

batch_context_lengths: list[Buffer]

buffers

property buffers: tuple[Buffer, ...]

Returns positional Buffer inputs for model ABI calls.
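Because, as the docstring says, the model ABI takes its inputs positionally, buffers must flatten the fields into a fixed order. The sketch below shows one plausible way to do that with a dataclass. The declaration-order traversal, the skipping of None fields, and the splicing of per-device lists are all assumptions about the real implementation, and SketchInputs is a hypothetical class that uses plain strings in place of Buffer objects.

```python
from dataclasses import dataclass, fields
from typing import Any, Optional


@dataclass
class SketchInputs:
    """Hypothetical, trimmed-down inputs container (not the MAX class)."""

    tokens: Any
    input_row_offsets: Any
    signal_buffers: list  # one entry per device
    draft_tokens: Optional[Any] = None

    @property
    def buffers(self) -> tuple:
        """Flatten fields, in declaration order, into a positional tuple."""
        out = []
        for f in fields(self):
            value = getattr(self, f.name)
            if value is None:
                continue  # assumed: absent optional inputs contribute nothing
            if isinstance(value, list):
                out.extend(value)  # assumed: per-device lists are spliced in
            else:
                out.append(value)
        return tuple(out)


inputs = SketchInputs("tokens", "offsets", ["sig0", "sig1"])
print(inputs.buffers)  # ('tokens', 'offsets', 'sig0', 'sig1')
```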

data_parallel_splits

data_parallel_splits: Buffer

draft_kv_blocks

draft_kv_blocks: list[Buffer] | None = None

draft_tokens

draft_tokens: Buffer | None = None

host_input_row_offsets

host_input_row_offsets: Buffer

in_thinking_phase

in_thinking_phase: Buffer | None = None

Per-row boolean flag marking the batch rows whose generation is currently inside a <think>...</think> block; the relaxed acceptance check consumes it.

input_row_offsets

input_row_offsets: Buffer

max_k

max_k: Buffer | None = None

min_top_p

min_top_p: Buffer | None = None

return_n_logits

return_n_logits: Buffer

seed

seed: Buffer | None = None

signal_buffers

signal_buffers: list[Buffer]

temperature

temperature: Buffer | None = None

tokens

tokens: Buffer

top_k

top_k: Buffer | None = None

top_p

top_p: Buffer | None = None
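The sampling fields (seed, temperature, top_k, max_k, top_p, min_top_p) carry no descriptions here. The NumPy sketch below shows per-row temperature, top-k, and top-p (nucleus) sampling under two assumptions: each per-row field is a 1-D array with one entry per batch row, and max_k and min_top_p are batch-wide aggregates (the largest top_k and the smallest top_p) that let a kernel size its work once for the whole batch. sample_rows is a hypothetical helper, not a MAX API, and it uses a single seed for the whole batch for simplicity.

```python
import numpy as np


def sample_rows(logits, temperature, top_k, top_p, seed=0):
    """Sketch of per-row temperature, top-k, and top-p sampling.

    logits: float array [batch, vocab]; temperature, top_p: float [batch];
    top_k: int [batch] with every entry >= 1.
    """
    rng = np.random.default_rng(seed)
    # Assumed meaning of max_k / min_top_p: batch-wide bounds.
    max_k = int(top_k.max())
    min_top_p = float(top_p.min())
    apply_top_p = min_top_p < 1.0  # skip nucleus filtering if no row needs it
    out = np.empty(logits.shape[0], dtype=np.int64)
    for row in range(logits.shape[0]):
        scaled = logits[row] / max(float(temperature[row]), 1e-6)
        # No row can pick outside the max_k highest-scoring tokens.
        shortlist = np.argsort(-scaled)[:max_k]
        cand = shortlist[: int(top_k[row])]  # this row's own top-k
        probs = np.exp(scaled[cand] - scaled[cand].max())
        probs /= probs.sum()
        if apply_top_p:
            # Keep the smallest prefix whose cumulative mass reaches top_p.
            keep = int(np.searchsorted(np.cumsum(probs), top_p[row])) + 1
            keep = min(keep, len(probs))
            cand, probs = cand[:keep], probs[:keep] / probs[:keep].sum()
        out[row] = cand[rng.choice(len(cand), p=probs)]
    return out, max_k, min_top_p


logits = np.array([[0.0, 5.0, 1.0], [3.0, 0.0, 0.0]])
tokens, max_k, min_top_p = sample_rows(
    logits,
    temperature=np.array([1.0, 1.0]),
    top_k=np.array([1, 2]),
    top_p=np.array([1.0, 0.5]),
)
print(tokens.tolist(), max_k, min_top_p)  # [1, 0] 2 0.5
```

Bounding the candidate set by a single batch-wide max_k lets one sort serve every row, which is the kind of saving the aggregate fields appear designed for.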

UnifiedMTPGemma4Model

class max.pipelines.architectures.unified_mtp_gemma4.UnifiedMTPGemma4Model(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN, return_hidden_states=ReturnHiddenStates.NONE)

Bases: AlwaysSignalBuffersMixin, PipelineModelWithKVCache[TextContext]

Gemma4 with MTP, with the merge, target, rejection, and shift stages of speculative decoding compiled into one graph.
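The description names a rejection stage but does not spell out its acceptance rule, and the in_thinking_phase field mentions a relaxed acceptance mode for some rows. For reference only, the NumPy sketch below implements the standard speculative-sampling rule for a single sequence: accept each draft token with probability min(1, p/q), resample from the normalized residual max(0, p - q) at the first rejection, and sample a bonus token from the target if every draft token is accepted. verify_draft is a hypothetical helper and is not claimed to match the fused graph.

```python
import numpy as np


def verify_draft(draft_tokens, draft_probs, target_probs, rng):
    """Standard speculative-sampling verification for one sequence (sketch).

    draft_tokens: [k] token ids proposed by the draft model.
    draft_probs:  [k, vocab] draft distributions q the tokens were sampled from.
    target_probs: [k + 1, vocab] target distributions p at each draft position,
                  plus one extra row for the bonus token.
    Returns the tokens emitted this step (between 1 and k + 1 of them).
    """
    emitted = []
    for i, tok in enumerate(draft_tokens):
        p, q = target_probs[i, tok], draft_probs[i, tok]  # q > 0: tok came from q
        # Accept the draft token with probability min(1, p / q).
        if rng.random() < min(1.0, p / q):
            emitted.append(int(tok))
            continue
        # First rejection: resample from the normalized residual max(0, p - q).
        residual = np.maximum(target_probs[i] - draft_probs[i], 0.0)
        emitted.append(int(rng.choice(len(residual), p=residual / residual.sum())))
        return emitted
    # Every draft token was accepted: sample a bonus token from the target.
    emitted.append(int(rng.choice(target_probs.shape[1], p=target_probs[-1])))
    return emitted


rng = np.random.default_rng(0)
q = np.array([[0.5, 0.5, 0.0], [0.0, 0.5, 0.5]])  # draft distributions
p = np.vstack([q, [[1.0, 0.0, 0.0]]])  # target agrees, plus a bonus row
print(verify_draft(np.array([0, 2]), q, p, rng))  # [0, 2, 0]
```

Resampling from the residual rather than from p is what keeps the emitted tokens distributed exactly as if they had been sampled from the target model alone.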

calculate_max_seq_len()

classmethod calculate_max_seq_len(pipeline_config, huggingface_config)

Calculates the optimal max sequence length for the model.

The default implementation delegates to model_config_cls. Override it when pipeline-model semantics differ from the config (for example, to bound max_length where the config is permissive).

Parameters:

  • pipeline_config (PipelineConfig) – Configuration for the pipeline.
  • huggingface_config (AutoConfig) – Hugging Face model configuration.

Returns:

The maximum sequence length to use.

Return type:

int
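The docstring describes overriding this classmethod to bound the sequence length. The sketch below shows that kind of override using hypothetical stand-in config classes; the attribute names max_length and max_position_embeddings are assumptions and may not match PipelineConfig or the Gemma4 Hugging Face config.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class StubPipelineConfig:
    """Hypothetical stand-in: the user-requested limit, if any."""

    max_length: Optional[int] = None


@dataclass
class StubHFConfig:
    """Hypothetical stand-in: the context window the checkpoint supports."""

    max_position_embeddings: int = 8192


class BoundedModel:
    @classmethod
    def calculate_max_seq_len(cls, pipeline_config, huggingface_config) -> int:
        limit = huggingface_config.max_position_embeddings
        requested = pipeline_config.max_length
        # Honor a smaller user request, but never exceed the model's limit.
        return limit if requested is None else min(requested, limit)


print(BoundedModel.calculate_max_seq_len(StubPipelineConfig(4096), StubHFConfig()))  # 4096
```

Clamping silently is one design choice; raising an error when the request exceeds the limit is another that surfaces misconfiguration earlier.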

estimate_activation_memory()

classmethod estimate_activation_memory(pipeline_config, huggingface_config)

Estimates the activation memory required for model execution.

This accounts for temporary memory buffers used during model execution, such as intermediate activations and working buffers.

The default implementation returns 0 for backward compatibility. Models with significant activation memory requirements should override this method to provide accurate estimates.

Parameters:

  • pipeline_config (PipelineConfig) – Pipeline configuration.
  • huggingface_config (AutoConfig) – Hugging Face model configuration.

Returns:

Estimated activation memory in bytes.

Return type:

int
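Since the default returns 0, an override has to supply its own estimate, typically by summing the largest temporary buffers. The sketch below uses a deliberately rough, hypothetical heuristic: a float32 logits buffer for every returned position in every row, plus one hidden-state-sized working buffer per token in the largest batch. StubSizes, its fields, and the numbers are illustrative assumptions, not Gemma4's real dimensions or MAX's formula.

```python
from dataclasses import dataclass


@dataclass
class StubSizes:
    """Hypothetical model and batch sizes (illustrative values only)."""

    max_batch_size: int
    logits_per_row: int  # assumed: draft tokens + 1 when verifying drafts
    vocab_size: int
    hidden_size: int
    max_batch_tokens: int


def estimate_activation_bytes(s: StubSizes, hidden_dtype_bytes: int = 2) -> int:
    # float32 logits (4 bytes) for every returned position in every row.
    logits = s.max_batch_size * s.logits_per_row * s.vocab_size * 4
    # One hidden-state-sized buffer per token, 2 bytes each (e.g. bfloat16).
    hidden = s.max_batch_tokens * s.hidden_size * hidden_dtype_bytes
    return logits + hidden


sizes = StubSizes(max_batch_size=8, logits_per_row=4, vocab_size=100_000,
                  hidden_size=2048, max_batch_tokens=4096)
print(estimate_activation_bytes(sizes))  # 29577216 (12,800,000 + 16,777,216)
```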

execute()

execute(model_inputs)

Executes the model and returns all three graph outputs used for speculative decoding.

Parameters:

model_inputs (ModelInputs)

Return type:

UnifiedEagleOutputs

load_model()

load_model(session)

Parameters:

session (InferenceSession)

Return type:

Model

model

model: Model

model_config_cls

model_config_cls

alias of Gemma4ForConditionalGenerationConfig

prepare_initial_token_inputs()

prepare_initial_token_inputs(replica_batches, kv_cache_inputs=None, return_n_logits=1, draft_tokens=None, draft_kv_cache_buffers=None, **kwargs)

Prepares the initial inputs to be passed to execute().

The inputs and functionality can vary per model. For example, model inputs could include encoded tensors, unique IDs per tensor when using a KV cache manager, and kv_cache_inputs (or None if the model does not use KV cache). This method typically batches encoded tensors, claims a KV cache slot if needed, and returns the inputs and caches.

Return type:

UnifiedMTPGemma4Inputs
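prepare_initial_token_inputs batches encoded tensors into the tokens and input_row_offsets fields. The sketch below assumes a ragged layout: all rows' tokens flattened into one 1-D buffer, with input_row_offsets holding batch + 1 boundaries so that row i spans tokens[offsets[i]:offsets[i + 1]]. It uses NumPy arrays in place of Buffer, and the int64 token and uint32 offset dtypes are assumptions. batch_ragged is a hypothetical helper, not part of MAX.

```python
import numpy as np


def batch_ragged(sequences):
    """Flatten variable-length token arrays into one buffer plus row offsets.

    Row i occupies tokens[offsets[i]:offsets[i + 1]], so offsets has
    len(sequences) + 1 entries and starts at 0.
    """
    lengths = np.array([len(s) for s in sequences], dtype=np.uint32)
    offsets = np.zeros(len(sequences) + 1, dtype=np.uint32)
    offsets[1:] = np.cumsum(lengths)  # exclusive prefix sum of row lengths
    tokens = np.concatenate(sequences).astype(np.int64)
    return tokens, offsets


tokens, offsets = batch_ragged([np.array([5, 6, 7]), np.array([8]), np.array([9, 10])])
print(tokens.tolist(), offsets.tolist())  # [5, 6, 7, 8, 9, 10] [0, 3, 4, 6]
```

Flattening avoids padding every row to the longest sequence; the offsets are all a kernel needs to locate each row's tokens.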

prepare_next_token_inputs()

prepare_next_token_inputs(next_tokens, prev_model_inputs)

Return type:

UnifiedMTPGemma4Inputs