Python module
max.pipelines.architectures.gemma4
Gemma 4 vision-language architecture for multimodal text generation.
Gemma3_MultiModalModel
class max.pipelines.architectures.gemma4.Gemma3_MultiModalModel(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN)
Bases: AlwaysSignalBuffersMixin, PipelineModelWithKVCache[Gemma4Context]
Gemma 3 multimodal pipeline model for text generation.
This class integrates the Gemma 3 multimodal architecture with the MAX pipeline infrastructure, handling model loading, KV cache management, and input preparation for inference.
-
Parameters:
-
- pipeline_config (PipelineConfig) – The configuration settings for the entire pipeline.
- session (InferenceSession) – The MAX inference session managing the runtime.
- huggingface_config – The configuration loaded from HuggingFace (transformers.AutoConfig).
- devices (list[Device]) – A list of MAX devices (max.driver.Device) to run the model on.
- kv_cache_config (KVCacheConfig) – Configuration settings for the Key-Value cache (max.pipelines.max_config.KVCacheConfig).
- weights (Weights) – The model weights (max.graph.weights.Weights).
- adapter (WeightsAdapter | None) – An optional adapter to modify weights before loading (max.graph.weights.WeightsAdapter).
- return_logits (ReturnLogits) – The number of top logits to return from the model execution.
calculate_max_seq_len()
classmethod calculate_max_seq_len(pipeline_config, huggingface_config)
Calculates the maximum sequence length for the Gemma 3 multimodal model.
-
Parameters:
-
- pipeline_config (PipelineConfig)
- huggingface_config (AutoConfig)
-
Return type:
estimate_activation_memory()
classmethod estimate_activation_memory(pipeline_config, huggingface_config)
Estimates the activation memory required for model execution.
This accounts for temporary memory buffers used during model execution, such as intermediate activations and working buffers.
The default implementation returns 0 for backward compatibility. Models with significant activation memory requirements should override this method to provide accurate estimates.
-
Parameters:
-
- pipeline_config (PipelineConfig) – Pipeline configuration
- huggingface_config (AutoConfig) – Hugging Face model configuration
-
Returns:
-
Estimated activation memory in bytes
-
Return type:
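As noted above, models with significant activation memory requirements should override this method. The following is a minimal, hypothetical sketch of such an estimate; the field names, buffer count, and dtype width are illustrative assumptions, not the actual MAX or HuggingFace attribute names.

```python
# Hypothetical sketch of an activation-memory estimate. The parameter
# names and the assumed number of live buffers are illustrative only.
def estimate_activation_memory(hidden_size: int, max_seq_len: int,
                               batch_size: int, dtype_bytes: int = 2) -> int:
    """Rough upper bound: a few working buffers of shape
    [batch, seq, hidden] held at the activation dtype."""
    live_buffers = 4  # assumed number of simultaneously live intermediates
    return live_buffers * batch_size * max_seq_len * hidden_size * dtype_bytes

print(estimate_activation_memory(hidden_size=2560, max_seq_len=8192,
                                 batch_size=1))  # estimated bytes
```

A real override would derive these quantities from the pipeline and HuggingFace configs rather than taking them as arguments.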
execute()
execute(model_inputs)
Execute the vision model (if needed), then the language model.
-
Parameters:
-
model_inputs (ModelInputs)
-
Return type:
get_kv_params()
classmethod get_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)
Gets the parameters required to configure the KV cache for the Gemma 3 multimodal model.
-
Parameters:
-
- huggingface_config (AutoConfig)
- pipeline_config (PipelineConfig)
- devices (list[DeviceRef])
- kv_cache_config (KVCacheConfig)
- cache_dtype (DType)
-
Return type:
language_model
language_model: Model
The compiled and initialized MAX Engine model ready for inference.
load_model()
load_model(session)
Loads the compiled Gemma3 MultiModal models into the MAX Engine session.
-
Returns:
-
A tuple of (vision_model, language_model).
-
Parameters:
-
session (InferenceSession)
-
Return type:
model
property model: Model
Expose language model for graph capture/replay.
Only the language model is captured since vision runs during prefill only.
prepare_initial_token_inputs()
prepare_initial_token_inputs(replica_batches, kv_cache_inputs=None, return_n_logits=1)
Prepare inputs for the first execution pass.
-
Parameters:
-
Return type:
prepare_next_token_inputs()
prepare_next_token_inputs(next_tokens, prev_model_inputs)
Prepares the secondary inputs to be passed to execute().
While prepare_initial_token_inputs manages the inputs for the first pass, this function updates those inputs for each subsequent step in a multi-step execution pattern.
-
Parameters:
-
- next_tokens (Buffer)
- prev_model_inputs (ModelInputs)
-
Return type:
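The interplay between prepare_initial_token_inputs(), execute(), and prepare_next_token_inputs() can be illustrated with a runnable mock. The real MAX types (ModelInputs, Buffer) are replaced with plain Python stand-ins here; only the call sequence mirrors the documented methods.

```python
# Illustrative mock of the multi-step execution pattern. Only the call
# sequence reflects the documented API; the bodies are stand-ins.
class MockPipelineModel:
    def prepare_initial_token_inputs(self, batch):
        return {"tokens": list(batch)}          # first-pass inputs

    def execute(self, model_inputs):
        # pretend the model emits the next token id
        return model_inputs["tokens"][-1] + 1

    def prepare_next_token_inputs(self, next_token, prev_inputs):
        # update the previous inputs for the next decode step
        return {"tokens": prev_inputs["tokens"] + [next_token]}

model = MockPipelineModel()
inputs = model.prepare_initial_token_inputs([1, 2, 3])
for _ in range(3):                               # three decode steps
    next_token = model.execute(inputs)
    inputs = model.prepare_next_token_inputs(next_token, inputs)

print(inputs["tokens"])  # [1, 2, 3, 4, 5, 6]
```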
release()
release(request_id)
Release vision encoder cache for a completed request.
-
Parameters:
-
request_id (RequestID)
-
Return type:
-
None
vision_model
vision_model: Model
The compiled and initialized MAX Engine vision model ready for inference.
Gemma4ForConditionalGenerationConfig
class max.pipelines.architectures.gemma4.Gemma4ForConditionalGenerationConfig(*, devices, dtype, kv_params, image_token_index, video_token_index=262144, text_config, vision_config, tie_word_embeddings=False)
Bases: ArchConfigWithKVAndVisionCache
Base configuration for Gemma 4 multimodal models.
This is the top-level config that composes text and vision sub-configs. Model-specific parameters live in the respective sub-configs.
-
Parameters:
calculate_max_seq_len()
static calculate_max_seq_len(pipeline_config, huggingface_config)
Calculates the maximum sequence length for the Gemma 4 model.
-
Parameters:
-
- pipeline_config (PipelineConfig)
- huggingface_config (AutoConfig)
-
Return type:
construct_kv_params()
static construct_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)
Constructs KV cache parameters from the top-level HuggingFace config.
-
Parameters:
-
- huggingface_config (AutoConfig) – The top-level HuggingFace config (with text_config).
- pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
- devices (list[DeviceRef]) – Target devices for the model.
- kv_cache_config (KVCacheConfig) – KV cache configuration settings.
- cache_dtype (DType) – Data type for the KV cache.
-
Returns:
-
Configured KV cache parameters.
-
Return type:
devices
Devices to run the model with.
dtype
dtype: DType
DType of the model weights and input.
estimate_vision_cache_entry_bytes()
static estimate_vision_cache_entry_bytes(huggingface_config)
Estimate per-entry bytes for the vision encoder cache.
The worst-case token count per image is position_embedding_size / pooling_kernel_size², stored at the text hidden size in bfloat16.
-
Parameters:
-
huggingface_config (AutoConfig)
-
Return type:
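A worked example of this estimate with hypothetical config values (the real values come from the HuggingFace vision and text configs, not the constants below):

```python
# Hypothetical config values for illustration only.
position_embedding_size = 4096   # assumed vision position-embedding count
pooling_kernel_size = 4          # assumed pooling kernel size
text_hidden_size = 2560          # assumed text hidden size

# Worst-case tokens per image, then bytes per cache entry at bfloat16.
tokens_per_image = position_embedding_size // pooling_kernel_size**2
entry_bytes = tokens_per_image * text_hidden_size * 2  # bfloat16 = 2 bytes
print(tokens_per_image, entry_bytes)  # 256 1310720
```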
finalize()
finalize(huggingface_config, state_dict, return_logits)
Finalize with state_dict-dependent fields.
Parses quantization config from the weights and finalizes the text sub-config.
-
Parameters:
-
- huggingface_config (AutoConfig) – HuggingFace model configuration.
- state_dict (dict[str, WeightData]) – Model weights dictionary.
- return_logits (ReturnLogits) – Return logits configuration.
-
Return type:
-
None
get_kv_params()
get_kv_params()
Returns the KV cache parameters.
-
Return type:
get_max_seq_len()
get_max_seq_len()
Returns the maximum sequence length from the embedded text config.
-
Return type:
image_token_index
image_token_index: int
The image token index to encode the image prompt.
initialize()
classmethod initialize(pipeline_config, model_config=None)
Initializes from pipeline configuration.
-
Parameters:
-
- pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
- model_config (MAXModelConfig | None) – Optional model config override.
-
Returns:
-
An initialized config instance.
-
Return type:
initialize_from_config()
classmethod initialize_from_config(pipeline_config, huggingface_config)
Initializes from pipeline and HuggingFace configs.
Fields that depend on the state_dict should be set via finalize().
-
Parameters:
-
- pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
- huggingface_config (AutoConfig) – Top-level HuggingFace model configuration.
-
Returns:
-
A config instance ready for finalization.
-
Return type:
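The split between initialize_from_config() and finalize() can be sketched with a mock two-phase config: fields knowable before the weights load are set first, and state_dict-dependent fields (such as the quantization scheme) are filled in afterward. The field names below are illustrative stand-ins, not the real Gemma4ForConditionalGenerationConfig fields.

```python
# Mock of the two-phase configuration pattern; names are stand-ins.
from dataclasses import dataclass
from typing import Optional


@dataclass
class MockConfig:
    max_seq_len: int
    quantization: Optional[str] = None  # unknown until weights are inspected

    @classmethod
    def initialize_from_config(cls, max_seq_len: int) -> "MockConfig":
        # Phase 1: fields derivable from static configs alone.
        return cls(max_seq_len=max_seq_len)

    def finalize(self, state_dict: dict) -> None:
        # Phase 2: parse state_dict-dependent fields from loaded weights.
        self.quantization = state_dict.get("quant_method")


cfg = MockConfig.initialize_from_config(max_seq_len=8192)
cfg.finalize({"quant_method": "gptq"})
print(cfg.quantization)  # gptq
```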
kv_params
kv_params: MultiKVCacheParams
KV cache parameters.
text_config
text_config: Gemma4TextConfig
The config object of the text backbone.
tie_word_embeddings
tie_word_embeddings: bool = False
Whether to tie weight embeddings. When true, the output linear layer uses the same weight as the embedding layer.
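A minimal sketch of what weight tying means, using a toy three-token vocabulary: the output projection reuses the embedding matrix (transposed) instead of allocating separate weights, so both layers share storage.

```python
# Toy weight-tying sketch: one shared [vocab, hidden] matrix serves both
# the embedding lookup and the output (logit) projection.
embedding_weight = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]  # vocab=3, hidden=2

def embed(token_id):
    # embedding lookup: row of the shared matrix
    return embedding_weight[token_id]

def output_logits(hidden):
    # tied output layer: logits = hidden @ embedding_weight.T
    return [sum(h * w for h, w in zip(hidden, row)) for row in embedding_weight]

logits = output_logits(embed(1))  # embed, then project back to vocab space
print(logits)
```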
video_token_index
video_token_index: int = 262144
The video token index to encode the video prompt.
vision_config
vision_config: Gemma4VisionConfig
The config object of the vision encoder.