Python module

max.pipelines.architectures.gemma3multimodal

Gemma 3 vision-language architecture for multimodal text generation.

Gemma3ForConditionalGenerationConfig​

class max.pipelines.architectures.gemma3multimodal.Gemma3ForConditionalGenerationConfig(*, boi_token_index, eoi_token_index, devices, dtype, kv_params, image_token_index, initializer_range, interleaved_rope_weights, mm_tokens_per_image, return_logits, tie_word_embeddings, text_config, vision_config, attention_bias=False, quant_config=None, head_dim=256, num_key_value_heads=4)

Bases: ArchVLConfigWithTextSubconfig, ArchConfigWithStoredKVParams, ArchConfigWithKVCache

Base configuration for Gemma 3 models.

Contains parameters specific to the Gemma 3 architecture, typically extracted from a HuggingFace configuration object’s text config.

Parameters:

attention_bias​

attention_bias: bool = False

Whether to use a bias in the query, key, value and output projection layers during self-attention.

boi_token_index​

boi_token_index: int

The begin-of-image token index, used to wrap the image prompt.

devices​

devices: list[DeviceRef]

Devices to run the model with.

dtype​

dtype: DType

DType of the model weights and input.

eoi_token_index​

eoi_token_index: int

The end-of-image token index, used to wrap the image prompt.

finalize()​

finalize(huggingface_config, state_dict, return_logits)

Finalizes the Gemma3ForConditionalGenerationConfig instance by setting the fields that depend on the state_dict.

Parameters:

  • huggingface_config (AutoConfig) – HuggingFace model configuration.
  • state_dict (dict[str, WeightData]) – Model weights dictionary.
  • return_logits (ReturnLogits) – Return logits configuration.

Return type:

None

get_num_layers()​

static get_num_layers(huggingface_config)

Returns the layer count for the decoder stack. Override this method when the HuggingFace configuration stores the count in a different field.

Parameters:

huggingface_config (AutoConfig)

Return type:

int

head_dim​

head_dim: int = 256

The attention head dimension.

image_token_index​

image_token_index: int

The image token index, used to encode the image prompt.

initialize()​

classmethod initialize(pipeline_config, model_config=None)

Initializes a Gemma3ForConditionalGenerationConfig instance from pipeline configuration.

Parameters:

Returns:

A Gemma3ForConditionalGenerationConfig instance with fields initialized from config.

Return type:

Self

initialize_from_config()​

classmethod initialize_from_config(pipeline_config, huggingface_config)

Initializes a Gemma3ForConditionalGenerationConfig from pipeline and HuggingFace configs.

This method creates a config instance with all fields that can be determined from the pipeline and HuggingFace configurations, without needing the state_dict. Fields that depend on the state_dict should be set via the finalize() method.

Parameters:

  • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
  • huggingface_config (AutoConfig) – HuggingFace model configuration.

Returns:

A Gemma3ForConditionalGenerationConfig instance ready for finalization.

Return type:

Self
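
The split between initialize_from_config() and finalize() is a two-phase construction pattern. The sketch below illustrates it with a stand-in class. TwoPhaseConfigSketch, its fields, and the plain-dict inputs are hypothetical and are not the MAX implementation; in particular, num_weight_tensors is an invented example of a weight-dependent field.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class TwoPhaseConfigSketch:
    """Hypothetical stand-in showing the initialize-then-finalize pattern."""

    hidden_size: int
    # Hypothetical weight-dependent field; unknown until finalize() runs.
    num_weight_tensors: Optional[int] = None

    @classmethod
    def initialize_from_config(
        cls, hf_config: dict[str, Any]
    ) -> "TwoPhaseConfigSketch":
        # Phase 1: use only values available from configuration objects.
        return cls(hidden_size=hf_config["hidden_size"])

    def finalize(self, state_dict: dict[str, Any]) -> None:
        # Phase 2: fill in values that require inspecting the weights.
        self.num_weight_tensors = len(state_dict)


# Illustrative inputs only.
config = TwoPhaseConfigSketch.initialize_from_config({"hidden_size": 2560})
print(config.num_weight_tensors)  # None (not finalized yet)
config.finalize({"embed.weight": object(), "norm.weight": object()})
print(config.num_weight_tensors)  # 2
```

Keeping the weight-dependent step separate means a config can be built and inspected before any weights are loaded.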

initializer_range​

initializer_range: float

Standard deviation for weight initialization.

interleaved_rope_weights​

interleaved_rope_weights: bool

True if the RoPE weights are stored in interleaved complex format.

kv_params​

kv_params: KVCacheParams

KV cache parameters.

mm_tokens_per_image​

mm_tokens_per_image: int

The number of tokens per image embedding.
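
As a rough illustration of how boi_token_index, eoi_token_index, image_token_index, and mm_tokens_per_image relate, the sketch below expands one image into a run of placeholder tokens wrapped by begin-of-image and end-of-image markers. The function name, the token IDs, and the exact layout are assumptions made for illustration and are not taken from the MAX implementation.

```python
def expand_image_placeholder(
    boi_token_index: int,
    eoi_token_index: int,
    image_token_index: int,
    mm_tokens_per_image: int,
) -> list[int]:
    """Return the token IDs that stand in for one image (assumed layout)."""
    return (
        [boi_token_index]
        + [image_token_index] * mm_tokens_per_image
        + [eoi_token_index]
    )


# Hypothetical IDs chosen only for readability.
image_span = expand_image_placeholder(
    boi_token_index=1,
    eoi_token_index=2,
    image_token_index=9,
    mm_tokens_per_image=4,
)
print(image_span)  # [1, 9, 9, 9, 9, 2]
```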

num_key_value_heads​

num_key_value_heads: int = 4

The number of key-value heads used to implement Grouped Query Attention (GQA). If num_key_value_heads equals num_attention_heads, the model uses Multi-Head Attention (MHA); if num_key_value_heads is 1, the model uses Multi-Query Attention (MQA); otherwise it uses GQA. When converting a multi-head checkpoint to a GQA checkpoint, each group's key and value head should be constructed by mean-pooling the original heads within that group.
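
The relationship between num_key_value_heads and the attention variant can be sketched as a small classifier. The function names are hypothetical, and the divisibility check is an assumption about well-formed configurations rather than something this page states.

```python
def attention_variant(num_attention_heads: int, num_key_value_heads: int) -> str:
    """Classify the attention variant implied by the two head counts."""
    if num_key_value_heads < 1 or num_attention_heads < 1:
        raise ValueError("head counts must be positive")
    # Assumption: query heads are split evenly across key-value heads.
    if num_attention_heads % num_key_value_heads != 0:
        raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
    if num_key_value_heads == num_attention_heads:
        return "MHA"
    if num_key_value_heads == 1:
        return "MQA"
    return "GQA"


def query_heads_per_kv_head(num_attention_heads: int, num_key_value_heads: int) -> int:
    """Number of query heads that share each key-value head."""
    attention_variant(num_attention_heads, num_key_value_heads)  # validates inputs
    return num_attention_heads // num_key_value_heads


# 8 query heads is an illustrative value; 4 is the documented default above.
print(attention_variant(8, 4))        # GQA
print(query_heads_per_kv_head(8, 4))  # 2
```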

quant_config​

quant_config: QuantConfig | None = None

Scaled quantization configuration.

return_logits​

return_logits: ReturnLogits

Whether to return the last token, all logits, or a variable number of logits.
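
The three behaviors can be pictured as different ways of selecting rows from a per-token logits table. The sketch below uses a stand-in enum and plain lists for a single sequence; LogitsMode, select_logits, and last_n are hypothetical names and do not describe the MAX ReturnLogits API.

```python
from enum import Enum


class LogitsMode(Enum):
    """Hypothetical stand-in for the three documented behaviors."""

    LAST_TOKEN = "last_token"
    ALL = "all"
    VARIABLE = "variable"


def select_logits(
    per_token_logits: list[list[float]],
    mode: LogitsMode,
    last_n: int = 1,
) -> list[list[float]]:
    """Pick which rows (one row per token) to return for one sequence."""
    if mode is LogitsMode.LAST_TOKEN:
        return per_token_logits[-1:]
    if mode is LogitsMode.ALL:
        return list(per_token_logits)
    if last_n < 1:
        raise ValueError("last_n must be at least 1")
    return per_token_logits[-last_n:]


rows = [[0.1, 0.9], [0.4, 0.6], [0.7, 0.3]]  # three tokens, vocabulary of two
print(select_logits(rows, LogitsMode.LAST_TOKEN))          # [[0.7, 0.3]]
print(len(select_logits(rows, LogitsMode.ALL)))            # 3
print(select_logits(rows, LogitsMode.VARIABLE, last_n=2))  # [[0.4, 0.6], [0.7, 0.3]]
```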

text_config​

text_config: Gemma3Config

The config object of the text backbone.

tie_word_embeddings​

tie_word_embeddings: bool

Whether to tie weight embeddings. When true, the output linear layer uses the same weight as the embedding layer.
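
Weight tying can be illustrated with plain lists: the same matrix that maps a token ID to its embedding row is reused to project a hidden vector onto the vocabulary. The function names and the tiny matrix are hypothetical and serve only to show the idea.

```python
def embed(token_id: int, embedding: list[list[float]]) -> list[float]:
    """Look up the embedding row for a token."""
    return embedding[token_id]


def output_logits(hidden: list[float], embedding: list[list[float]]) -> list[float]:
    """Project a hidden vector onto the vocabulary using the embedding rows."""
    return [sum(h * w for h, w in zip(hidden, row)) for row in embedding]


# Vocabulary of three tokens, hidden size of two (illustrative values).
embedding = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
print(embed(2, embedding))                   # [1.0, 1.0]
print(output_logits([2.0, 3.0], embedding))  # [2.0, 3.0, 5.0]
```

With tying enabled, the model stores one vocabulary-sized matrix instead of two.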

vision_config​

vision_config: Gemma3VisionConfig

Custom vision config or dict

Gemma3MultiModalModelInputs​

class max.pipelines.architectures.gemma3multimodal.Gemma3MultiModalModelInputs(tokens, input_row_offsets, signal_buffers, return_n_logits, pixel_values=None, image_token_indices=None, *, kv_cache_inputs=None, lora_ids=None, lora_ranks=None, hidden_states=None)

Bases: ModelInputs

Inputs for the Gemma 3 multimodal model.

Encapsulates the input tensors that the Gemma 3 multimodal model requires for text and vision processing.

Parameters:

has_vision_inputs​

property has_vision_inputs: bool

Check if this input contains vision data.

image_token_indices​

image_token_indices: list[Buffer] | None = None

Pre-computed indices of image tokens in the input sequence.
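
One way to picture these indices is as the positions at which the image placeholder token occurs in the flattened token sequence. The helper below is a hypothetical sketch over plain lists; the token IDs are invented, and the real field holds device buffers rather than Python lists.

```python
def find_image_token_indices(tokens: list[int], image_token_index: int) -> list[int]:
    """Return the positions in `tokens` that hold the image placeholder ID."""
    return [pos for pos, tok in enumerate(tokens) if tok == image_token_index]


# Hypothetical sequence: 9 is the image placeholder, 1 and 2 wrap the image.
tokens = [5, 1, 9, 9, 9, 2, 6]
print(find_image_token_indices(tokens, image_token_index=9))  # [2, 3, 4]
```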

input_row_offsets​

input_row_offsets: ndarray[tuple[Any, ...], dtype[integer[Any]]] | list[Buffer]

Tensor containing the offsets for each row in the ragged input sequence, or the attention mask for the padded input sequence. For distributed execution, this can be a list of tensors, one per device.
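
For the ragged case, a common convention is to concatenate all sequences into one flat token list and record where each row starts and ends. The sketch below assumes that convention (offsets of length batch + 1, starting at 0); the function name is hypothetical and the real fields hold arrays or device buffers.

```python
def to_ragged(sequences: list[list[int]]) -> tuple[list[int], list[int]]:
    """Flatten variable-length sequences and compute their row offsets."""
    tokens: list[int] = []
    offsets = [0]
    for seq in sequences:
        tokens.extend(seq)
        offsets.append(len(tokens))  # end of this row, start of the next
    return tokens, offsets


flat_tokens, row_offsets = to_ragged([[5, 6, 7], [8], [9, 10]])
print(flat_tokens)  # [5, 6, 7, 8, 9, 10]
print(row_offsets)  # [0, 3, 4, 6]

# Row i spans flat_tokens[row_offsets[i]:row_offsets[i + 1]].
print(flat_tokens[row_offsets[1]:row_offsets[2]])  # [8]
```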

pixel_values​

pixel_values: list[Buffer] | None = None

Raw pixel values for vision inputs, with shape [batch, channels, height, width].

return_n_logits​

return_n_logits: Buffer

Number of logits to return, used by speculative decoding for example.

signal_buffers​

signal_buffers: list[Buffer]

Device buffers used for synchronization in communication collectives.

tokens​

tokens: ndarray[tuple[Any, ...], dtype[integer[Any]]] | Buffer

Tensor containing the input token IDs.

Gemma3VisionConfig​

class max.pipelines.architectures.gemma3multimodal.Gemma3VisionConfig(hidden_act, hidden_size, image_size, intermediate_size, layer_norm_eps, num_attention_heads, num_hidden_layers, num_channels, patch_size, attention_bias=True, attention_dropout=0.0, vision_use_head=False)

Bases: object

The vision-specific config for Gemma 3. More info at: https://huggingface.co/google/gemma-3-4b-it/blob/main/config.json

Parameters:

  • hidden_act (str)
  • hidden_size (int)
  • image_size (int)
  • intermediate_size (int)
  • layer_norm_eps (float)
  • num_attention_heads (int)
  • num_hidden_layers (int)
  • num_channels (int)
  • patch_size (int)
  • attention_bias (bool)
  • attention_dropout (float)
  • vision_use_head (bool)

attention_bias​

attention_bias: bool = True

attention_dropout​

attention_dropout: float = 0.0

The dropout ratio for the attention probabilities.

hidden_act​

hidden_act: str

The non-linear activation function (function or string) in the encoder and pooler. “gelu”, “gelu_tanh”, “relu”, “sigmoid”, “silu”, and “tanh” are supported.

hidden_size​

hidden_size: int

Dimensionality of the encoder layers and the pooler layer.

image_size​

image_size: int

The size (resolution) of each image.

initialize_from_config()​

classmethod initialize_from_config(hf_vision_config)

Initialize Gemma3VisionConfig from HuggingFace vision config.

Parameters:

hf_vision_config (AutoConfig)

Return type:

Gemma3VisionConfig
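
A minimal sketch of this kind of conversion, assuming the HuggingFace vision config exposes each value as an attribute with the same name. VisionConfigSketch and from_hf_like are hypothetical stand-ins that mirror the documented fields and defaults; they are not the MAX implementation, and the values below are illustrative.

```python
from dataclasses import MISSING, dataclass, fields
from types import SimpleNamespace


@dataclass
class VisionConfigSketch:
    """Stand-in mirroring the documented Gemma3VisionConfig fields."""

    hidden_act: str
    hidden_size: int
    image_size: int
    intermediate_size: int
    layer_norm_eps: float
    num_attention_heads: int
    num_hidden_layers: int
    num_channels: int
    patch_size: int
    attention_bias: bool = True
    attention_dropout: float = 0.0
    vision_use_head: bool = False

    @classmethod
    def from_hf_like(cls, hf_vision_config: object) -> "VisionConfigSketch":
        kwargs = {}
        for field in fields(cls):
            if hasattr(hf_vision_config, field.name):
                kwargs[field.name] = getattr(hf_vision_config, field.name)
            elif field.default is MISSING:
                # Required fields have no default to fall back on.
                raise ValueError(f"missing required field: {field.name}")
        return cls(**kwargs)


# Illustrative values only; a real config would come from HuggingFace.
hf_like = SimpleNamespace(
    hidden_act="gelu_tanh", hidden_size=64, image_size=32,
    intermediate_size=128, layer_norm_eps=1e-6, num_attention_heads=4,
    num_hidden_layers=2, num_channels=3, patch_size=8,
)
vision_config = VisionConfigSketch.from_hf_like(hf_like)
print(vision_config.attention_bias)  # True (documented default)
```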

intermediate_size​

intermediate_size: int

Dimension of the MLP representations.

layer_norm_eps​

layer_norm_eps: float

The epsilon used by the layer normalization layers.

num_attention_heads​

num_attention_heads: int

Number of attention heads for each attention layer in the Transformer encoder.

num_channels​

num_channels: int

Number of channels in the input images.

num_hidden_layers​

num_hidden_layers: int

Number of hidden layers in the Transformer encoder.

patch_size​

patch_size: int

The size (resolution) of each patch.
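
Together, image_size and patch_size determine how many patches the vision encoder sees per image. The sketch below assumes square images split into non-overlapping square patches; the function name is hypothetical, and 896 and 14 are example values that should be checked against the config.json linked above.

```python
def num_patches(image_size: int, patch_size: int) -> int:
    """Count non-overlapping square patches in a square image (assumed layout)."""
    if patch_size < 1 or image_size < patch_size:
        raise ValueError("patch_size must be positive and no larger than image_size")
    if image_size % patch_size != 0:
        raise ValueError("image_size must be a multiple of patch_size")
    patches_per_side = image_size // patch_size
    return patches_per_side * patches_per_side


print(num_patches(image_size=896, patch_size=14))  # 4096
```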

vision_use_head​

vision_use_head: bool = False

Whether to use attention heads for vision.

Gemma3_MultiModalModel​

class max.pipelines.architectures.gemma3multimodal.Gemma3_MultiModalModel(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN)

Bases: AlwaysSignalBuffersMixin, PipelineModelWithKVCache[TextAndVisionContext]

Gemma 3 multimodal pipeline model for text generation.

This class integrates the Gemma 3 multimodal architecture with the MAX pipeline infrastructure, handling model loading, KV cache management, and input preparation for inference.

Parameters:

estimate_activation_memory()​

classmethod estimate_activation_memory(pipeline_config, huggingface_config)

Estimates the activation memory required for model execution.

This accounts for temporary memory buffers used during model execution, such as intermediate activations and working buffers.

The default implementation returns 0 for backward compatibility. Models with significant activation memory requirements should override this method to provide accurate estimates.

Parameters:

  • pipeline_config (PipelineConfig) – Pipeline configuration
  • huggingface_config (AutoConfig) – Hugging Face model configuration

Returns:

Estimated activation memory in bytes

Return type:

int

execute()​

execute(model_inputs)

Executes the vision model if the inputs contain vision data, then executes the language model. The language model receives either the resulting image embeddings or an empty placeholder.

Parameters:

model_inputs (ModelInputs)

Return type:

ModelOutputs
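
The control flow described for execute() can be sketched with plain callables. Everything here is a hypothetical stand-in: run_multimodal_step, the stub models, and the use of an empty list as the placeholder are assumptions for illustration, not the MAX implementation.

```python
from typing import Callable, Optional


def run_multimodal_step(
    tokens: list[int],
    pixel_values: Optional[list],
    vision_model: Callable[[list], list],
    language_model: Callable[[list[int], list], object],
) -> object:
    """Run the vision model only when vision inputs exist, then the language model."""
    if pixel_values is not None:
        image_embeddings = vision_model(pixel_values)
    else:
        image_embeddings = []  # empty placeholder for text-only inputs
    return language_model(tokens, image_embeddings)


# Stub models that only report what they received.
def stub_vision(pixels: list) -> list:
    return [f"embedding-{i}" for i in range(len(pixels))]


def stub_language(tokens: list[int], embeddings: list) -> tuple[int, int]:
    return (len(tokens), len(embeddings))


print(run_multimodal_step([1, 2, 3], None, stub_vision, stub_language))        # (3, 0)
print(run_multimodal_step([1, 2, 3], ["img"], stub_vision, stub_language))     # (3, 1)
```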

get_num_layers()​

classmethod get_num_layers(huggingface_config)

Gets the number of hidden layers from the HuggingFace configuration.

Parameters:

huggingface_config (AutoConfig)

Return type:

int

language_model​

language_model: Model

The compiled and initialized MAX Engine model ready for inference.

load_model()​

load_model(session)

Loads the compiled Gemma3 MultiModal models into the MAX Engine session.

Parameters:

session (InferenceSession)

Returns:

A tuple of (vision_model, language_model).

Return type:

tuple[Model, Model]

model_config_cls​

model_config_cls

alias of Gemma3ForConditionalGenerationConfig

prepare_initial_token_inputs()​

prepare_initial_token_inputs(replica_batches, kv_cache_inputs=None, return_n_logits=1)

Prepares the inputs for the first execution pass of the multimodal model.

Parameters:

Return type:

ModelInputs

vision_model​

vision_model: Model

The compiled and initialized MAX Engine vision model ready for inference.