Python module

max.pipelines.architectures.gemma3multimodal

Gemma 3 vision-language architecture for multimodal text generation.

Gemma3ForConditionalGenerationConfig​

class max.pipelines.architectures.gemma3multimodal.Gemma3ForConditionalGenerationConfig(*, boi_token_index, eoi_token_index, devices, dtype, kv_params, image_token_index, initializer_range, interleaved_rope_weights, mm_tokens_per_image, return_logits, tie_word_embeddings, text_config, vision_config, attention_bias=False, quant_config=None, head_dim=256, num_key_value_heads=4)

Bases: ArchConfigWithKVCache

Base configuration for Gemma 3 models.

Contains parameters specific to the Gemma 3 architecture, typically extracted from a HuggingFace configuration object’s text config.

attention_bias​

attention_bias: bool = False

Whether to use a bias in the query, key, value and output projection layers during self-attention.

boi_token_index​

boi_token_index: int

The begin-of-image token index used to wrap the image prompt.

calculate_max_seq_len()​

static calculate_max_seq_len(pipeline_config, huggingface_config)

Parameters:

  • pipeline_config (PipelineConfig)
  • huggingface_config (AutoConfig)

Return type:

int

construct_kv_params()​

static construct_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)

Return type:

KVCacheParams

devices​

devices: list[DeviceRef]

Devices to run the model with.

dtype​

dtype: DType

DType of the model weights and input.

eoi_token_index​

eoi_token_index: int

The end-of-image token index used to wrap the image prompt.

finalize()​

finalize(huggingface_config, state_dict, return_logits)

Finalize the Gemma3ForConditionalGenerationConfig instance with state_dict dependent fields.

Parameters:

  • huggingface_config (AutoConfig) – HuggingFace model configuration.
  • state_dict (dict[str, WeightData]) – Model weights dictionary.
  • return_logits (ReturnLogits) – Return logits configuration.

Return type:

None

get_kv_params()​

get_kv_params()

Returns the KV cache parameters.

Return type:

KVCacheParams

get_max_seq_len()​

get_max_seq_len()

Returns the maximum sequence length from the embedded text config.

Return type:

int

get_num_layers()​

static get_num_layers(huggingface_config)

Parameters:

huggingface_config (AutoConfig)

Return type:

int

head_dim​

head_dim: int = 256

The attention head dimension.

image_token_index​

image_token_index: int

The image token index used to encode the image prompt.
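
Taken together, boi_token_index, image_token_index, and eoi_token_index describe how an image is spliced into the token stream: the image span is wrapped with the begin- and end-of-image tokens, and the span between them holds mm_tokens_per_image copies of the image token, whose slots the vision embeddings later fill. A minimal sketch of that layout (the index values follow the gemma-3-4b-it config and are shown for illustration only):

```python
# Illustrative values; the real indices come from the HuggingFace config.
boi_token_index = 255999
eoi_token_index = 256000
image_token_index = 262144
mm_tokens_per_image = 256

prompt_tokens = [2, 106, 2364]  # placeholder text token IDs
image_span = (
    [boi_token_index]
    + [image_token_index] * mm_tokens_per_image
    + [eoi_token_index]
)
tokens = prompt_tokens + image_span  # image embeddings fill the image-token slots
```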

initialize()​

classmethod initialize(pipeline_config, model_config=None)

Initializes a Gemma3ForConditionalGenerationConfig instance from pipeline configuration.

Returns:

A Gemma3ForConditionalGenerationConfig instance with fields initialized from config.

Return type:

Self

initialize_from_config()​

classmethod initialize_from_config(pipeline_config, huggingface_config)

Initializes a Gemma3ForConditionalGenerationConfig from pipeline and HuggingFace configs.

This method creates a config instance with all fields that can be determined from the pipeline and HuggingFace configurations, without needing the state_dict. Fields that depend on the state_dict should be set via the finalize() method.

Parameters:

  • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
  • huggingface_config (AutoConfig) – HuggingFace model configuration.

Returns:

A Gemma3ForConditionalGenerationConfig instance ready for finalization.

Return type:

Self
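
The two-phase pattern above keeps config construction separate from weight-dependent fields. A minimal sketch, assuming pipeline_config, huggingface_config, state_dict, and return_logits are already available from the surrounding pipeline code:

```python
from max.pipelines.architectures.gemma3multimodal import (
    Gemma3ForConditionalGenerationConfig,
)

# Phase 1: fields derivable from the pipeline and HuggingFace configs alone.
config = Gemma3ForConditionalGenerationConfig.initialize_from_config(
    pipeline_config, huggingface_config
)

# Phase 2: fields that depend on the loaded weights.
config.finalize(huggingface_config, state_dict, return_logits)
```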

initializer_range​

initializer_range: float

Standard deviation for weight initialization.

interleaved_rope_weights​

interleaved_rope_weights: bool

True if the RoPE weights are in interleaved complex format.

kv_params​

kv_params: KVCacheParams

KV cache parameters.

mm_tokens_per_image​

mm_tokens_per_image: int

The number of tokens per image embedding.

num_key_value_heads​

num_key_value_heads: int = 4

The number of key-value heads used to implement Grouped Query Attention (GQA). If num_key_value_heads equals num_attention_heads, the model uses Multi-Head Attention (MHA); if num_key_value_heads is 1, the model uses Multi-Query Attention (MQA); otherwise GQA is used. When converting a multi-head checkpoint to a GQA checkpoint, each group's key and value head should be constructed by mean-pooling all the original heads within that group.
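
As a quick worked example of the three regimes (the head counts are illustrative, not taken from this config):

```python
num_attention_heads = 8  # illustrative query-head count

def attention_kind(num_key_value_heads: int) -> str:
    """Classify the attention variant implied by the KV head count."""
    if num_key_value_heads == num_attention_heads:
        return "MHA"  # every query head has its own key-value head
    if num_key_value_heads == 1:
        return "MQA"  # all query heads share a single key-value head
    return "GQA"  # query heads are grouped over the key-value heads

print(attention_kind(8))  # MHA
print(attention_kind(1))  # MQA
print(attention_kind(4))  # GQA: 8 // 4 = 2 query heads per key-value head
```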

quant_config​

quant_config: QuantConfig | None = None

Scaled quantization configuration.

return_logits​

return_logits: ReturnLogits

Whether to return the last token, all logits, or a variable number of logits.

text_config​

text_config: Gemma3Config

The config object of the text backbone.

tie_word_embeddings​

tie_word_embeddings: bool

Whether to tie word embeddings. When true, the output linear layer uses the same weight as the embedding layer.
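
In practice, tying means the output projection reuses the embedding matrix transposed, so no separate output weight is stored. A minimal numpy sketch of the idea (shapes are illustrative):

```python
import numpy as np

vocab_size, hidden_size = 1000, 64  # illustrative shapes
embedding = np.random.randn(vocab_size, hidden_size).astype(np.float32)

hidden_states = np.random.randn(3, hidden_size).astype(np.float32)  # [seq, hidden]

# With tie_word_embeddings=True, the LM head is the embedding matrix
# transposed, so logits come straight from the embedding weights.
logits = hidden_states @ embedding.T  # [seq, vocab]
```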

vision_config​

vision_config: Gemma3VisionConfig

Custom vision config or dict.

Gemma3MultiModalModelInputs​

class max.pipelines.architectures.gemma3multimodal.Gemma3MultiModalModelInputs(tokens, input_row_offsets, signal_buffers, return_n_logits, pixel_values=None, image_token_indices=None, *, kv_cache_inputs=None, lora_ids=None, lora_ranks=None, hidden_states=None)

Bases: ModelInputs

A class representing inputs for the Gemma 3 multimodal model.

This class encapsulates the input tensors required by the Gemma 3 multimodal model for both text and vision processing.

has_vision_inputs​

property has_vision_inputs: bool

Check if this input contains vision data.

image_token_indices​

image_token_indices: list[Buffer] | None = None

Pre-computed indices of image tokens in the input sequence.

input_row_offsets​

input_row_offsets: ndarray[tuple[Any, ...], dtype[integer[Any]]] | list[Buffer]

Tensor containing the offsets for each row in the ragged input sequence, or the attention mask for the padded input sequence. For distributed execution, this can be a list of tensors, one per device.
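
For ragged batches, the offsets mark where each sequence begins in the flattened token buffer, with one trailing entry for the total length. A small numpy illustration:

```python
import numpy as np

# Three sequences of lengths 3, 5, and 2, flattened into one buffer.
seq_lens = [3, 5, 2]
input_row_offsets = np.cumsum([0] + seq_lens, dtype=np.uint32)  # [0, 3, 8, 10]

flat_tokens = np.arange(10, dtype=np.int64)
# Row i spans flat_tokens[input_row_offsets[i] : input_row_offsets[i + 1]].
row_1 = flat_tokens[input_row_offsets[1] : input_row_offsets[2]]  # tokens 3..7
```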

pixel_values​

pixel_values: list[Buffer] | None = None

Raw pixel values for vision inputs, with shape [batch, channels, height, width].

return_n_logits​

return_n_logits: Buffer

Number of logits to return (used, for example, by speculative decoding).

signal_buffers​

signal_buffers: list[Buffer]

Device buffers used for synchronization in communication collectives.

tokens​

tokens: ndarray[tuple[Any, ...], dtype[integer[Any]]] | Buffer

Tensor containing the input token IDs.

Gemma3VisionConfig​

class max.pipelines.architectures.gemma3multimodal.Gemma3VisionConfig(hidden_act, hidden_size, image_size, intermediate_size, layer_norm_eps, num_attention_heads, num_hidden_layers, num_channels, patch_size, attention_bias=True, attention_dropout=0.0, vision_use_head=False)

Bases: object

The vision-specific config for Gemma 3. More info at: https://huggingface.co/google/gemma-3-4b-it/blob/main/config.json

Parameters:

  • hidden_act (str)
  • hidden_size (int)
  • image_size (int)
  • intermediate_size (int)
  • layer_norm_eps (float)
  • num_attention_heads (int)
  • num_hidden_layers (int)
  • num_channels (int)
  • patch_size (int)
  • attention_bias (bool)
  • attention_dropout (float)
  • vision_use_head (bool)

attention_bias​

attention_bias: bool = True

attention_dropout​

attention_dropout: float = 0.0

The dropout ratio for the attention probabilities.

hidden_act​

hidden_act: str

The non-linear activation function (function or string) in the encoder and pooler. "gelu", "gelu_tanh", "relu", "sigmoid", "silu", and "tanh" are supported.

hidden_size​

hidden_size: int

Dimensionality of the encoder layers and the pooler layer.

image_size​

image_size: int

The size (resolution) of each image.

initialize_from_config()​

classmethod initialize_from_config(hf_vision_config)

Initializes a Gemma3VisionConfig from the HuggingFace vision config.

Parameters:

hf_vision_config (AutoConfig)

Return type:

Gemma3VisionConfig

intermediate_size​

intermediate_size: int

Dimension of the MLP representations.

layer_norm_eps​

layer_norm_eps: float

The epsilon used by the layer normalization layers.

num_attention_heads​

num_attention_heads: int

Number of attention heads for each attention layer in the Transformer encoder.

num_channels​

num_channels: int

Number of channels in the input images.

num_hidden_layers​

num_hidden_layers: int

Number of hidden layers in the Transformer encoder.

patch_size​

patch_size: int

The size (resolution) of each patch.
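
The image and patch sizes determine how many patches the vision encoder produces, which Gemma 3 then pools down to mm_tokens_per_image tokens. A worked example using the values in the gemma-3-4b-it config linked above (treat them as illustrative):

```python
image_size, patch_size = 896, 14  # from the gemma-3-4b-it vision config
patches_per_side = image_size // patch_size  # 64
num_patches = patches_per_side**2  # 4096 encoder patches per image
mm_tokens_per_image = 256  # tokens after pooling
pool_factor = num_patches // mm_tokens_per_image  # 16 patches per image token
```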

vision_use_head​

vision_use_head: bool = False

Whether to use the vision model's attention pooling head.

Gemma3_MultiModalModel​

class max.pipelines.architectures.gemma3multimodal.Gemma3_MultiModalModel(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN)

Bases: AlwaysSignalBuffersMixin, PipelineModelWithKVCache[TextAndVisionContext]

Gemma 3 multimodal pipeline model for text generation.

This class integrates the Gemma 3 multimodal architecture with the MAX pipeline infrastructure, handling model loading, KV cache management, and input preparation for inference.

calculate_max_seq_len()​

classmethod calculate_max_seq_len(pipeline_config, huggingface_config)

Calculates the maximum sequence length for the Gemma 3 model.

Parameters:

  • pipeline_config (PipelineConfig)
  • huggingface_config (AutoConfig)

Return type:

int

estimate_activation_memory()​

classmethod estimate_activation_memory(pipeline_config, huggingface_config)

Estimates the activation memory required for model execution.

This accounts for temporary memory buffers used during model execution, such as intermediate activations and working buffers.

The default implementation returns 0 for backward compatibility. Models with significant activation memory requirements should override this method to provide accurate estimates.

Parameters:

  • pipeline_config (PipelineConfig) – Pipeline configuration
  • huggingface_config (AutoConfig) – Hugging Face model configuration

Returns:

Estimated activation memory in bytes

Return type:

int

execute()​

execute(model_inputs)

Executes the vision model if vision inputs are present, then executes the language model, passing through either the image embeddings or an empty placeholder.

Parameters:

model_inputs (ModelInputs)

Return type:

ModelOutputs
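
In rough pseudo-Python, the dispatch looks like the sketch below; the embedding plumbing is simplified, and any name not documented on this page is an assumption, not the literal implementation:

```python
def execute_sketch(model, inputs):
    # Run the vision model only when the batch actually carries image data.
    if inputs.has_vision_inputs:
        image_embeddings = model.vision_model.execute(*inputs.pixel_values)
    else:
        image_embeddings = None  # stands in for the empty placeholder
    # The language model consumes the tokens plus the (possibly empty)
    # image embeddings.
    return model.language_model.execute(inputs.tokens, image_embeddings)
```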

get_kv_params()​

classmethod get_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)

Gets the parameters required to configure the KV cache for Gemma 3.

Return type:

KVCacheParams

get_num_layers()​

classmethod get_num_layers(huggingface_config)

Gets the number of hidden layers from the HuggingFace configuration.

Parameters:

huggingface_config (AutoConfig)

Return type:

int

language_model​

language_model: Model

The compiled and initialized MAX Engine language model, ready for inference.

load_model()​

load_model(session)

Loads the compiled Gemma 3 multimodal models into the MAX Engine session.

Parameters:

session (InferenceSession)

Returns:

A tuple of (vision_model, language_model).

Return type:

tuple[Model, Model]

prepare_initial_token_inputs()​

prepare_initial_token_inputs(replica_batches, kv_cache_inputs=None, return_n_logits=1)

Prepares the inputs for the first execution pass of the multimodal model.

Return type:

ModelInputs

prepare_next_token_inputs()​

prepare_next_token_inputs(next_tokens, prev_model_inputs)

Prepares the secondary inputs to be passed to execute().

While prepare_initial_token_inputs() manages the initial inputs, this function updates the inputs for each step in a multi-step execution pattern.

Return type:

ModelInputs
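
Together, prepare_initial_token_inputs(), execute(), and prepare_next_token_inputs() form the usual multi-step decode loop. A hedged sketch (num_steps, replica_batches, kv_cache_inputs, and the sample() helper are assumptions for illustration):

```python
inputs = model.prepare_initial_token_inputs(replica_batches, kv_cache_inputs)
for _ in range(num_steps):
    outputs = model.execute(inputs)
    next_tokens = sample(outputs)  # hypothetical sampling helper
    # Update the inputs rather than rebuilding them from the batch.
    inputs = model.prepare_next_token_inputs(next_tokens, inputs)
```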

vision_model​

vision_model: Model

The compiled and initialized MAX Engine vision model ready for inference.