Python module
max.pipelines.architectures.gemma3multimodal
Gemma 3 vision-language architecture for multimodal text generation.
Gemma3ForConditionalGenerationConfig
class max.pipelines.architectures.gemma3multimodal.Gemma3ForConditionalGenerationConfig(*, boi_token_index, eoi_token_index, devices, dtype, kv_params, image_token_index, initializer_range, interleaved_rope_weights, mm_tokens_per_image, return_logits, tie_word_embeddings, text_config, vision_config, attention_bias=False, quant_config=None, head_dim=256, num_key_value_heads=4)
Bases: ArchConfigWithKVCache
Base configuration for Gemma 3 models.
Contains parameters specific to the Gemma 3 architecture, typically extracted from a HuggingFace configuration object's text config.
Parameters:
- boi_token_index (int)
- eoi_token_index (int)
- devices (list[DeviceRef])
- dtype (DType)
- kv_params (KVCacheParams)
- image_token_index (int)
- initializer_range (float)
- interleaved_rope_weights (bool)
- mm_tokens_per_image (int)
- return_logits (ReturnLogits)
- tie_word_embeddings (bool)
- text_config (Gemma3Config)
- vision_config (Gemma3VisionConfig)
- attention_bias (bool)
- quant_config (QuantConfig | None)
- head_dim (int)
- num_key_value_heads (int)
attention_bias
attention_bias: bool = False
Whether to use a bias in the query, key, value, and output projection layers during self-attention.
boi_token_index
boi_token_index: int
The begin-of-image token index used to wrap the image prompt.
calculate_max_seq_len()
static calculate_max_seq_len(pipeline_config, huggingface_config)
Parameters:
- pipeline_config (PipelineConfig)
- huggingface_config (AutoConfig)
Return type:
construct_kv_params()
static construct_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)
Parameters:
- huggingface_config (AutoConfig)
- pipeline_config (PipelineConfig)
- devices (list[DeviceRef])
- kv_cache_config (KVCacheConfig)
- cache_dtype (DType)
Return type:
devices
devices: list[DeviceRef]
Devices to run the model with.
dtype
dtype: DType
DType of the model weights and input.
eoi_token_index
eoi_token_index: int
The end-of-image token index used to wrap the image prompt.
finalize()
finalize(huggingface_config, state_dict, return_logits)
Finalize the Gemma3ForConditionalGenerationConfig instance with state_dict dependent fields.
Parameters:
- huggingface_config (AutoConfig) – HuggingFace model configuration.
- state_dict (dict[str, WeightData]) – Model weights dictionary.
- return_logits (ReturnLogits) – Return logits configuration.
Return type:
None
get_kv_params()
get_kv_params()
Returns the KV cache parameters.
Return type:
get_max_seq_len()
get_max_seq_len()
Returns the maximum sequence length from the embedded text config.
Return type:
get_num_layers()
static get_num_layers(huggingface_config)
Parameters:
- huggingface_config (AutoConfig)
Return type:
head_dim
head_dim: int = 256
The attention head dimension.
image_token_index
image_token_index: int
The image token index used to encode the image prompt.
initialize()
classmethod initialize(pipeline_config, model_config=None)
Initializes a Gemma3ForConditionalGenerationConfig instance from pipeline configuration.
Parameters:
- pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
- model_config (MAXModelConfig | None)
Returns:
A Gemma3ForConditionalGenerationConfig instance with fields initialized from config.
Return type:
initialize_from_config()
classmethod initialize_from_config(pipeline_config, huggingface_config)
Initializes a Gemma3ForConditionalGenerationConfig from pipeline and HuggingFace configs.
This method creates a config instance with all fields that can be determined from the pipeline and HuggingFace configurations, without needing the state_dict. Fields that depend on the state_dict should be set via the finalize() method.
Parameters:
- pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
- huggingface_config (AutoConfig) – HuggingFace model configuration.
Returns:
A Gemma3ForConditionalGenerationConfig instance ready for finalization.
Return type:
initializer_range
initializer_range: float
Standard deviation for weight initialization.
interleaved_rope_weights
interleaved_rope_weights: bool
True if the rope weights are in interleaved complex format.
kv_params
kv_params: KVCacheParams
KV cache parameters.
mm_tokens_per_image
mm_tokens_per_image: int
The number of tokens per image embedding.
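Taken together, boi_token_index, eoi_token_index, image_token_index, and mm_tokens_per_image describe how an image is spliced into the token stream: a run of mm_tokens_per_image image tokens wrapped by the begin/end-of-image tokens. A minimal sketch of that layout, with made-up token index values (they are illustrative, not taken from a real checkpoint):

```python
BOI_TOKEN_INDEX = 255999    # hypothetical stand-in for boi_token_index
EOI_TOKEN_INDEX = 256000    # hypothetical stand-in for eoi_token_index
IMAGE_TOKEN_INDEX = 262144  # hypothetical stand-in for image_token_index
MM_TOKENS_PER_IMAGE = 256   # hypothetical stand-in for mm_tokens_per_image


def wrap_image_prompt(text_tokens: list[int]) -> list[int]:
    """Append one image's placeholder span to a token sequence."""
    image_span = (
        [BOI_TOKEN_INDEX]
        + [IMAGE_TOKEN_INDEX] * MM_TOKENS_PER_IMAGE  # one slot per image embedding
        + [EOI_TOKEN_INDEX]
    )
    return text_tokens + image_span


tokens = wrap_image_prompt([1, 42, 7])
```

At execution time, the image-token slots are replaced by the vision tower's embeddings; the surrounding text tokens are embedded normally.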
num_key_value_heads
num_key_value_heads: int = 4
The number of key/value heads used to implement Grouped Query Attention (GQA). If num_key_value_heads equals num_attention_heads, the model uses Multi-Head Attention (MHA); if num_key_value_heads is 1, the model uses Multi-Query Attention (MQA); otherwise GQA is used. When converting a multi-head checkpoint to a GQA checkpoint, each group's key and value head should be constructed by mean-pooling all the original heads within that group.
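The mean-pooling conversion mentioned above can be sketched with NumPy. This is an illustrative sketch of the general technique (the function name and weight layout are assumptions), not a MAX API:

```python
import numpy as np


def meanpool_kv_heads(kv_weight: np.ndarray, num_heads: int, num_kv_heads: int) -> np.ndarray:
    """Convert a K or V projection from MHA to GQA layout by mean-pooling.

    kv_weight: [num_heads * head_dim, hidden] -> [num_kv_heads * head_dim, hidden]
    """
    assert num_heads % num_kv_heads == 0
    head_dim = kv_weight.shape[0] // num_heads
    # Group consecutive heads, then average each group into one KV head.
    grouped = kv_weight.reshape(num_kv_heads, num_heads // num_kv_heads, head_dim, -1)
    return grouped.mean(axis=1).reshape(num_kv_heads * head_dim, -1)


w = np.arange(8 * 4 * 6, dtype=np.float64).reshape(8 * 4, 6)  # 8 heads, head_dim=4, hidden=6
w_gqa = meanpool_kv_heads(w, num_heads=8, num_kv_heads=4)
```

When num_kv_heads equals num_heads, the conversion is the identity, matching the MHA case described above.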
quant_config
quant_config: QuantConfig | None = None
Scaled quantization configuration.
return_logits
return_logits: ReturnLogits
Whether to return the last token, all logits, or a variable number of logits.
text_config
text_config: Gemma3Config
The config object of the text backbone.
tie_word_embeddings
tie_word_embeddings: bool
Whether to tie weight embeddings. When true, the output linear layer uses the same weight as the embedding layer.
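Weight tying can be illustrated with a small NumPy sketch (not the MAX implementation): the same matrix embeds input tokens and, transposed, projects hidden states back to vocabulary logits:

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, hidden = 100, 16
embedding = rng.normal(size=(vocab_size, hidden))  # one weight matrix, used twice


def embed(token_ids: np.ndarray) -> np.ndarray:
    return embedding[token_ids]          # input side: row lookup, [seq, hidden]


def lm_logits(hidden_states: np.ndarray) -> np.ndarray:
    return hidden_states @ embedding.T   # output side: same weight, transposed, [seq, vocab]


h = embed(np.array([3, 7]))
logits = lm_logits(h)
```

Tying halves the parameter count of the (often large) vocabulary projection, which is why many decoder-only checkpoints ship without a separate lm_head weight.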
vision_config
vision_config: Gemma3VisionConfig
Custom vision config or dict.
Gemma3MultiModalModelInputs
class max.pipelines.architectures.gemma3multimodal.Gemma3MultiModalModelInputs(tokens, input_row_offsets, signal_buffers, return_n_logits, pixel_values=None, image_token_indices=None, *, kv_cache_inputs=None, lora_ids=None, lora_ranks=None, hidden_states=None)
Bases: ModelInputs
A class representing inputs for the Gemma3 multimodal model.
This class encapsulates the input tensors required by the Gemma3 multimodal model for text and vision processing.
Parameters:
- tokens (ndarray[tuple[Any, ...], dtype[integer[Any]]] | Buffer) – Input token IDs.
- input_row_offsets (ndarray[tuple[Any, ...], dtype[integer[Any]]] | list[Buffer]) – Input row offsets (ragged tensors).
- return_n_logits (Buffer) – Number of logits to return.
- signal_buffers (list[Buffer]) – Device buffers for distributed communication.
- kv_cache_inputs (KVCacheInputs[Buffer, Buffer] | None) – Inputs for the KV cache.
- pixel_values (list[Buffer] | None) – Raw pixel values for vision inputs. Defaults to None.
- image_token_indices (list[Buffer] | None) – Pre-computed indices of image tokens. Defaults to None.
- lora_ids (Buffer | None)
- lora_ranks (Buffer | None)
- hidden_states (Buffer | list[Buffer] | None)
has_vision_inputs
property has_vision_inputs: bool
Check if this input contains vision data.
image_token_indices
image_token_indices: list[Buffer] | None
Pre-computed indices of image tokens in the input sequence.
input_row_offsets
input_row_offsets: ndarray[tuple[Any, ...], dtype[integer[Any]]] | list[Buffer]
Tensor containing the offsets for each row in the ragged input sequence, or the attention mask for the padded input sequence. For distributed execution, this can be a list of tensors, one per device.
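The ragged-tensor convention described above can be sketched concretely: variable-length sequences are packed into one flat token buffer, and the offsets array marks where each row begins (a minimal sketch of the assumed semantics, not MAX code):

```python
import numpy as np

# Three sequences of different lengths, packed into a single flat buffer.
seqs = [[1, 2, 3], [4], [5, 6]]
tokens = np.array([t for s in seqs for t in s])
input_row_offsets = np.cumsum([0] + [len(s) for s in seqs])

# Row i is recovered as tokens[input_row_offsets[i]:input_row_offsets[i + 1]].
row_1 = tokens[input_row_offsets[1]:input_row_offsets[2]]
```

Ragged batching avoids padding every sequence to the longest one in the batch, which saves both memory and compute for mixed-length batches.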
pixel_values
pixel_values: list[Buffer] | None
Raw pixel values for vision inputs, with shape [batch, channels, height, width].
return_n_logits
return_n_logits: Buffer
Number of logits to return; used, for example, by speculative decoding.
signal_buffers
signal_buffers: list[Buffer]
Device buffers used for synchronization in communication collectives.
tokens
tokens: ndarray[tuple[Any, ...], dtype[integer[Any]]] | Buffer
Tensor containing the input token IDs.
Gemma3VisionConfig
class max.pipelines.architectures.gemma3multimodal.Gemma3VisionConfig(hidden_act, hidden_size, image_size, intermediate_size, layer_norm_eps, num_attention_heads, num_hidden_layers, num_channels, patch_size, attention_bias=True, attention_dropout=0.0, vision_use_head=False)
Bases: object
The vision-specific config for Gemma 3. More info at: https://huggingface.co/google/gemma-3-4b-it/blob/main/config.json
Parameters:
attention_bias
attention_bias: bool = True
attention_dropout
attention_dropout: float = 0.0
The dropout ratio for the attention probabilities.
hidden_act
hidden_act: str
The non-linear activation function (function or string) in the encoder and pooler. "gelu", "gelu_tanh", "relu", "sigmoid", "silu", and "tanh" are supported.
hidden_size
hidden_size: int
Dimensionality of the encoder layers and the pooler layer.
image_size
image_size: int
The size (resolution) of each image.
initialize_from_config()
classmethod initialize_from_config(hf_vision_config)
Initialize Gemma3VisionConfig from HuggingFace vision config.
Parameters:
- hf_vision_config (AutoConfig)
Return type:
intermediate_size
intermediate_size: int
Dimension of the MLP representations.
layer_norm_eps
layer_norm_eps: float
The epsilon used by the layer normalization layers.
num_attention_heads
num_attention_heads: int
Number of attention heads for each attention layer in the Transformer encoder.
num_channels
num_channels: int
Number of channels in the input images.
num_hidden_layers
num_hidden_layers: int
Number of hidden layers in the Transformer encoder.
patch_size
patch_size: int
The size (resolution) of each patch.
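image_size and patch_size together determine how many patch tokens the vision encoder produces per image: a square image is split into non-overlapping square patches. A quick worked example, with values assumed from the gemma-3-4b-it vision config linked above (treat them as illustrative):

```python
image_size = 896  # assumed example value
patch_size = 14   # assumed example value

patches_per_side = image_size // patch_size  # 896 / 14 = 64
num_patches = patches_per_side ** 2          # 64 * 64 = 4096 patch tokens per image
```

These patch tokens are subsequently pooled down to mm_tokens_per_image embeddings before being spliced into the language model's token stream.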
vision_use_head
vision_use_head: bool = False
Whether to use the vision model's attention pooling head.
Gemma3_MultiModalModel
class max.pipelines.architectures.gemma3multimodal.Gemma3_MultiModalModel(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN)
Bases: AlwaysSignalBuffersMixin, PipelineModelWithKVCache[TextAndVisionContext]
Gemma 3 multimodal pipeline model for text generation.
This class integrates the Gemma 3 multimodal architecture with the MAX pipeline infrastructure, handling model loading, KV cache management, and input preparation for inference.
Parameters:
- pipeline_config (PipelineConfig) – The configuration settings for the entire pipeline.
- session (InferenceSession) – The MAX inference session managing the runtime.
- huggingface_config – The configuration loaded from HuggingFace (transformers.AutoConfig).
- devices (list[Device]) – A list of MAX devices (max.driver.Device) to run the model on.
- kv_cache_config (KVCacheConfig) – Configuration settings for the Key-Value cache (max.pipelines.max_config.KVCacheConfig).
- weights (Weights) – The model weights (max.graph.weights.Weights).
- adapter (WeightsAdapter | None) – An optional adapter to modify weights before loading (max.graph.weights.WeightsAdapter).
- return_logits (ReturnLogits) – The number of top logits to return from the model execution.
calculate_max_seq_len()
classmethod calculate_max_seq_len(pipeline_config, huggingface_config)
Calculates the maximum sequence length for the Gemma 3 multimodal model.
Parameters:
- pipeline_config (PipelineConfig)
- huggingface_config (AutoConfig)
Return type:
estimate_activation_memory()
classmethod estimate_activation_memory(pipeline_config, huggingface_config)
Estimates the activation memory required for model execution.
This accounts for temporary memory buffers used during model execution, such as intermediate activations and working buffers.
The default implementation returns 0 for backward compatibility. Models with significant activation memory requirements should override this method to provide accurate estimates.
Parameters:
- pipeline_config (PipelineConfig) – Pipeline configuration.
- huggingface_config (AutoConfig) – Hugging Face model configuration.
Returns:
Estimated activation memory in bytes.
Return type:
execute()
execute(model_inputs)
If required, executes the vision model, then continues to execute the language model, either passing through the image embeddings or creating an empty placeholder.
Parameters:
- model_inputs (ModelInputs)
Return type:
get_kv_params()
classmethod get_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)
Gets the parameters required to configure the KV cache for Gemma 3.
Parameters:
- huggingface_config (AutoConfig)
- pipeline_config (PipelineConfig)
- devices (list[DeviceRef])
- kv_cache_config (KVCacheConfig)
- cache_dtype (DType)
Return type:
get_num_layers()
classmethod get_num_layers(huggingface_config)
Gets the number of hidden layers from the HuggingFace configuration.
Parameters:
- huggingface_config (AutoConfig)
Return type:
language_model
language_model: Model
The compiled and initialized MAX Engine language model, ready for inference.
load_model()
load_model(session)
Loads the compiled Gemma3 multimodal models into the MAX Engine session.
Parameters:
- session (InferenceSession)
Returns:
A tuple of (vision_model, language_model).
Return type:
prepare_initial_token_inputs()
prepare_initial_token_inputs(replica_batches, kv_cache_inputs=None, return_n_logits=1)
Prepares the inputs for the first execution pass of the multimodal model.
Parameters:
- replica_batches (Sequence[Sequence[TextAndVisionContext]])
- kv_cache_inputs (KVCacheInputs[Buffer, Buffer] | None)
- return_n_logits (int)
Return type:
prepare_next_token_inputs()
prepare_next_token_inputs(next_tokens, prev_model_inputs)
Prepares the secondary inputs to be passed to execute().
While prepare_initial_token_inputs manages the initial inputs, this function updates the inputs for each subsequent step in a multi-step execution pattern.
Parameters:
- next_tokens (Buffer)
- prev_model_inputs (ModelInputs)
Return type:
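The multi-step pattern that prepare_initial_token_inputs and prepare_next_token_inputs implement can be sketched as a toy decode loop (illustrative names and a fake "model", not the MAX API):

```python
def prepare_initial_token_inputs(prompt: list[int]) -> dict:
    return {"tokens": list(prompt)}


def execute(inputs: dict) -> int:
    # Fake "model" for illustration: next token is simply last token + 1.
    return inputs["tokens"][-1] + 1


def prepare_next_token_inputs(next_token: int, prev_inputs: dict) -> dict:
    # Update the previous step's inputs with the newly sampled token.
    return {"tokens": prev_inputs["tokens"] + [next_token]}


inputs = prepare_initial_token_inputs([5])
generated = []
for _ in range(3):  # three decode steps
    next_token = execute(inputs)
    generated.append(next_token)
    inputs = prepare_next_token_inputs(next_token, inputs)
```

Building inputs once and cheaply updating them between steps is what makes multi-step (multi-token) execution efficient: the expensive preparation work is not repeated on every decode step.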
vision_model
vision_model: Model
The compiled and initialized MAX Engine vision model ready for inference.