Python module

max.pipelines.architectures.gemma4

Gemma 4 vision-language architecture for multimodal text generation.

Gemma3_MultiModalModel

class max.pipelines.architectures.gemma4.Gemma3_MultiModalModel(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN)

source

Bases: AlwaysSignalBuffersMixin, PipelineModelWithKVCache[Gemma4Context]

Gemma 3 multimodal pipeline model for text generation.

This class integrates the Gemma 3 multimodal architecture with the MAX pipeline infrastructure, handling model loading, KV cache management, and input preparation for inference.

Parameters:

calculate_max_seq_len()

classmethod calculate_max_seq_len(pipeline_config, huggingface_config)

source

Calculates the maximum sequence length for the Gemma 4 model.

Parameters:

Return type:

int

estimate_activation_memory()

classmethod estimate_activation_memory(pipeline_config, huggingface_config)

source

Estimates the activation memory required for model execution.

This accounts for temporary memory buffers used during model execution, such as intermediate activations and working buffers.

The default implementation returns 0 for backward compatibility. Models with significant activation memory requirements should override this method to provide accurate estimates.

Parameters:

  • pipeline_config (PipelineConfig) – Pipeline configuration
  • huggingface_config (AutoConfig) – Hugging Face model configuration

Returns:

Estimated activation memory in bytes

Return type:

int
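The override pattern described above can be sketched as follows. The class names, dict-style configs, and the sizing heuristic are illustrative assumptions, not the MAX API:

```python
class BasePipelineModel:
    """Stand-in for the MAX base class: the default implementation
    returns 0 for backward compatibility."""

    @classmethod
    def estimate_activation_memory(cls, pipeline_config, huggingface_config):
        return 0


class MyMultiModalModel(BasePipelineModel):
    """Hypothetical subclass with significant activation memory needs."""

    @classmethod
    def estimate_activation_memory(cls, pipeline_config, huggingface_config):
        # Rough upper bound: batch * seq_len * hidden * bytes per element,
        # doubled to account for intermediate working buffers.
        batch = pipeline_config["max_batch_size"]
        seq_len = pipeline_config["max_seq_len"]
        hidden = huggingface_config["hidden_size"]
        bytes_per_element = 2  # bfloat16
        return 2 * batch * seq_len * hidden * bytes_per_element
```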

execute()

execute(model_inputs)

source

Execute the vision model (if needed), then the language model.

Parameters:

model_inputs (ModelInputs)

Return type:

ModelOutputs
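A minimal sketch of the vision-then-language dispatch; the dict-based inputs and stand-in callables are assumptions for illustration, whereas the real method operates on `ModelInputs` and compiled `Model` objects:

```python
def execute(model_inputs, vision_model, language_model):
    # Run the vision encoder only when the batch carries image inputs
    # (i.e. during prefill); decode steps skip straight to the LM.
    image_embeddings = None
    if model_inputs.get("pixel_values") is not None:
        image_embeddings = vision_model(model_inputs["pixel_values"])
    return language_model(model_inputs["tokens"], image_embeddings)
```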

get_kv_params()

classmethod get_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)

source

Gets the parameters required to configure the KV cache for Gemma 4.

Parameters:

Return type:

MultiKVCacheParams

language_model

language_model: Model

source

The compiled and initialized MAX Engine model ready for inference.

load_model()

load_model(session)

source

Loads the compiled Gemma3 MultiModal models into the MAX Engine session.

Returns:

A tuple of (vision_model, language_model).

Parameters:

session (InferenceSession)

Return type:

tuple[Model, Model]

model

property model: Model

source

Expose language model for graph capture/replay.

Only the language model is captured since vision runs during prefill only.

prepare_initial_token_inputs()

prepare_initial_token_inputs(replica_batches, kv_cache_inputs=None, return_n_logits=1)

source

Prepare inputs for the first execution pass.

Parameters:

Return type:

ModelInputs

prepare_next_token_inputs()

prepare_next_token_inputs(next_tokens, prev_model_inputs)

source

Prepares the secondary inputs to be passed to execute().

While prepare_initial_token_inputs manages the initial inputs, this function updates them for each step of a multi-step execution pattern.

Parameters:

Return type:

ModelInputs
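The multi-step update can be sketched as follows; the dict-based inputs and field names (`tokens`, `pixel_values`) are illustrative assumptions, not the real `ModelInputs` structure:

```python
def prepare_next_token_inputs(next_tokens, prev_model_inputs):
    # Reuse the previous step's inputs, but swap in the newly sampled
    # tokens and drop prefill-only fields such as pixel values, since
    # the vision encoder runs only during prefill.
    updated = dict(prev_model_inputs)
    updated["tokens"] = next_tokens
    updated.pop("pixel_values", None)
    return updated
```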

release()

release(request_id)

source

Release vision encoder cache for a completed request.

Parameters:

request_id (RequestID)

Return type:

None
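The per-request cache semantics can be sketched with a hypothetical dictionary-backed cache (this class is not part of the MAX API; it only illustrates the release-on-completion contract):

```python
class VisionEncoderCache:
    """Hypothetical cache of vision embeddings keyed by request ID."""

    def __init__(self):
        self._entries = {}

    def put(self, request_id, embeddings):
        self._entries[request_id] = embeddings

    def get(self, request_id):
        return self._entries.get(request_id)

    def release(self, request_id):
        # Free cached embeddings for a completed request; releasing an
        # unknown or already-released ID is a no-op.
        self._entries.pop(request_id, None)
```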

vision_model

vision_model: Model

source

The compiled and initialized MAX Engine vision model ready for inference.

Gemma4ForConditionalGenerationConfig

class max.pipelines.architectures.gemma4.Gemma4ForConditionalGenerationConfig(*, devices, dtype, kv_params, image_token_index, video_token_index=262144, text_config, vision_config, tie_word_embeddings=False)

source

Bases: ArchConfigWithKVAndVisionCache

Base configuration for Gemma 4 multimodal models.

This is the top-level config that composes text and vision sub-configs. Model-specific parameters live in the respective sub-configs.

Parameters:

calculate_max_seq_len()

static calculate_max_seq_len(pipeline_config, huggingface_config)

source

Calculates the maximum sequence length for the Gemma 4 model.

Parameters:

Return type:

int

construct_kv_params()

static construct_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)

source

Constructs KV cache parameters from the top-level HuggingFace config.

Parameters:

  • huggingface_config (AutoConfig) – The top-level HuggingFace config (with text_config).
  • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
  • devices (list[DeviceRef]) – Target devices for the model.
  • kv_cache_config (KVCacheConfig) – KV cache configuration settings.
  • cache_dtype (DType) – Data type for the KV cache.

Returns:

Configured KV cache parameters.

Return type:

MultiKVCacheParams

devices

devices: list[DeviceRef]

source

Devices to run the model with.

dtype

dtype: DType

source

DType of the model weights and input.

estimate_vision_cache_entry_bytes()

static estimate_vision_cache_entry_bytes(huggingface_config)

source

Estimate per-entry bytes for the vision encoder cache.

Worst-case tokens per image is position_embedding_size / pooling_kernel_size², stored at the text hidden size in bfloat16.

Parameters:

huggingface_config (AutoConfig)

Return type:

int
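The estimate above reduces to a short arithmetic formula, sketched here with illustrative argument names (the real method reads these values from the HuggingFace AutoConfig):

```python
def estimate_vision_cache_entry_bytes(position_embedding_size,
                                      pooling_kernel_size,
                                      text_hidden_size):
    # Worst-case tokens per image = position_embedding_size / pooling_kernel_size^2,
    # each stored at the text hidden size in bfloat16 (2 bytes per element).
    tokens_per_image = position_embedding_size // pooling_kernel_size**2
    bytes_per_element = 2
    return tokens_per_image * text_hidden_size * bytes_per_element
```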

finalize()

finalize(huggingface_config, state_dict, return_logits)

source

Finalize with state_dict-dependent fields.

Parses quantization config from the weights and finalizes the text sub-config.

Parameters:

  • huggingface_config (AutoConfig) – HuggingFace model configuration.
  • state_dict (dict[str, WeightData]) – Model weights dictionary.
  • return_logits (ReturnLogits) – Return logits configuration.

Return type:

None

get_kv_params()

get_kv_params()

source

Returns the KV cache parameters.

Return type:

MultiKVCacheParams

get_max_seq_len()

get_max_seq_len()

source

Returns the maximum sequence length from the embedded text config.

Return type:

int

image_token_index

image_token_index: int

source

The image token index to encode the image prompt.

initialize()

classmethod initialize(pipeline_config, model_config=None)

source

Initializes from pipeline configuration.

Parameters:

  • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
  • model_config (MAXModelConfig | None) – Optional model config override.

Returns:

An initialized config instance.

Return type:

Self

initialize_from_config()

classmethod initialize_from_config(pipeline_config, huggingface_config)

source

Initializes from pipeline and HuggingFace configs.

Fields that depend on the state_dict should be set via finalize().

Parameters:

  • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
  • huggingface_config (AutoConfig) – Top-level HuggingFace model configuration.

Returns:

A config instance ready for finalization.

Return type:

Self

kv_params

kv_params: MultiKVCacheParams

source

KV cache parameters.

text_config

text_config: Gemma4TextConfig

source

The config object of the text backbone.

tie_word_embeddings

tie_word_embeddings: bool = False

source

Whether to tie the word embeddings. When true, the output linear layer uses the same weight as the embedding layer.
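Weight tying can be illustrated with a tiny pure-Python sketch (not the MAX implementation): the output projection computes logits as dot products of the hidden state against the same table used for embedding lookup.

```python
def tied_logits(token_ids, embedding_matrix):
    # Embed: look up one row per token.
    hidden = [embedding_matrix[t] for t in token_ids]

    def dot(a, b):
        return sum(x * y for x, y in zip(a, b))

    # Project out with the same matrix: one logit per vocabulary row,
    # so no separate output-layer weight is stored.
    return [[dot(h, row) for row in embedding_matrix] for h in hidden]
```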

video_token_index

video_token_index: int = 262144

source

The video token index to encode the video prompt.

vision_config

vision_config: Gemma4VisionConfig

source

The config object of the vision encoder.