Python module
max.pipelines.architectures.gemma4
Gemma 4 vision-language architecture for multimodal text generation.
Gemma3MultiModalModelInputs
class max.pipelines.architectures.gemma4.Gemma3MultiModalModelInputs(tokens, input_row_offsets, signal_buffers, return_n_logits, images=None, video=None, combined_embeds=None, combined_indices=None, *, kv_cache_inputs=None, lora_ids=None, lora_ranks=None, hidden_states=None)
Bases: ModelInputs
A class representing inputs for the Gemma3 multimodal model.
This class encapsulates the input tensors required by the Gemma3 multimodal model for text and vision processing.
Parameters:
- tokens (ndarray[tuple[Any, ...], dtype[integer[Any]]] | Buffer) – Input token IDs.
- input_row_offsets (ndarray[tuple[Any, ...], dtype[integer[Any]]] | list[Buffer]) – Input row offsets (ragged tensors).
- return_n_logits (Buffer) – Number of logits to return.
- signal_buffers (list[Buffer]) – Device buffers for distributed communication.
- kv_cache_inputs (KVCacheInputs[Buffer, Buffer] | None) – Combined KV cache inputs (sliding-window + global).
- images (ImageInputs | None) – Inputs to the image encoder.
- video (VideoInputs | None) – Inputs to the video encoder.
- combined_embeds (list[Buffer] | None)
- combined_indices (list[Buffer] | None)
- lora_ids (Buffer | None)
- lora_ranks (Buffer | None)
- hidden_states (Buffer | list[Buffer] | None)
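A minimal construction sketch for a text-only request. The token and offset arrays are toy values; signal_bufs, ret_n_logits, and kv_inputs are placeholders assumed to come from the surrounding pipeline's device and KV cache setup:

```python
import numpy as np

from max.pipelines.architectures.gemma4 import Gemma3MultiModalModelInputs

# Toy ragged batch: two sequences of lengths 3 and 2, packed back to back.
tokens_np = np.array([101, 2023, 2003, 101, 7592], dtype=np.int64)
offsets_np = np.array([0, 3, 5], dtype=np.uint32)  # row offsets into tokens_np

# signal_bufs, ret_n_logits, and kv_inputs are assumed to already exist
# (device buffers and KV cache inputs produced by the pipeline); they are
# not constructed here.
inputs = Gemma3MultiModalModelInputs(
    tokens=tokens_np,
    input_row_offsets=offsets_np,
    signal_buffers=signal_bufs,
    return_n_logits=ret_n_logits,
    images=None,  # text-only request: no image encoder inputs
    kv_cache_inputs=kv_inputs,
)
```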
buffers
Returns positional Buffer inputs for the language model ABI.
combined_embeds
combined_indices
images
images: ImageInputs | None = None
input_row_offsets
input_row_offsets: ndarray[tuple[Any, ...], dtype[integer[Any]]] | list[Buffer]
return_n_logits
return_n_logits: Buffer
signal_buffers
tokens
tokens: ndarray[tuple[Any, ...], dtype[integer[Any]]] | Buffer
video
video: VideoInputs | None = None
Gemma3_MultiModalModel
class max.pipelines.architectures.gemma4.Gemma3_MultiModalModel(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN)
Bases: AlwaysSignalBuffersMixin, PipelineModelWithKVCache[Gemma4Context]
Gemma 3 multimodal pipeline model for text generation.
This class integrates the Gemma 3 multimodal architecture with the MAX pipeline infrastructure, handling model loading, KV cache management, and input preparation for inference.
Parameters:
- pipeline_config (PipelineConfig) – The configuration settings for the entire pipeline.
- session (InferenceSession) – The MAX inference session managing the runtime.
- huggingface_config – The configuration loaded from HuggingFace (transformers.AutoConfig).
- devices (list[Device]) – A list of MAX devices (max.driver.Device) to run the model on.
- kv_cache_config (KVCacheConfig) – Configuration settings for the Key-Value cache (max.pipelines.max_config.KVCacheConfig).
- weights (Weights) – The model weights (max.graph.weights.Weights).
- adapter (WeightsAdapter | None) – An optional adapter to modify weights before loading (max.graph.weights.WeightsAdapter).
- return_logits (ReturnLogits) – The number of top logits to return from the model execution.
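A rough construction sketch, assuming pipeline_config, session, devices, kv_cache_config, and weights have already been produced by the serving layer (placeholder variables; not a verbatim MAX recipe). The optional adapter and return_logits arguments keep their defaults here:

```python
from max.pipelines.architectures.gemma4 import Gemma3_MultiModalModel

# pipeline_config, session, devices, kv_cache_config, and weights are
# placeholders prepared elsewhere by the pipeline infrastructure.
model = Gemma3_MultiModalModel(
    pipeline_config=pipeline_config,
    session=session,
    devices=devices,
    kv_cache_config=kv_cache_config,
    weights=weights,
)
```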
calculate_max_seq_len()
classmethod calculate_max_seq_len(pipeline_config, huggingface_config)
Calculates the maximum sequence length for the model.
Parameters:
- pipeline_config (PipelineConfig)
- huggingface_config (AutoConfig)
Return type:
estimate_activation_memory()
classmethod estimate_activation_memory(pipeline_config, huggingface_config)
Estimates the activation memory required for model execution.
This accounts for temporary memory buffers used during model execution, such as intermediate activations and working buffers.
The default implementation returns 0 for backward compatibility. Models with significant activation memory requirements should override this method to provide accurate estimates (see the sketch below).
Parameters:
- pipeline_config (PipelineConfig) – Pipeline configuration.
- huggingface_config (AutoConfig) – Hugging Face model configuration.
Returns:
Estimated activation memory in bytes.
Return type:
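A hedged sketch of what an override might look like in a hypothetical subclass. The sizing formula is illustrative only, not the estimate this class actually uses, and it assumes the HuggingFace config exposes vision_config and text_config fields with the names shown:

```python
class MyGemmaVariantModel(Gemma3_MultiModalModel):  # hypothetical subclass
    @classmethod
    def estimate_activation_memory(cls, pipeline_config, huggingface_config):
        # Illustrative budget: room for one image's worth of vision
        # activations at the text hidden size in bfloat16 (2 bytes/element).
        vision_cfg = huggingface_config.vision_config
        text_cfg = huggingface_config.text_config
        tokens_per_image = (
            vision_cfg.position_embedding_size // vision_cfg.pooling_kernel_size**2
        )
        return tokens_per_image * text_cfg.hidden_size * 2
```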
execute()
execute(model_inputs)
Execute the vision model (if needed), then the language model.
Parameters:
model_inputs (ModelInputs)
Return type:
get_kv_params()
classmethod get_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)
Gets the parameters required to configure the KV cache.
Parameters:
- huggingface_config (AutoConfig)
- pipeline_config (PipelineConfig)
- devices (list[DeviceRef])
- kv_cache_config (KVCacheConfig)
- cache_dtype (DType)
Return type:
language_model
language_model: Model
The compiled and initialized MAX Engine model ready for inference.
load_model()
load_model(session)
Loads the compiled Gemma3 multimodal models into the MAX Engine session.
Returns:
A tuple of (vision_model, language_model).
Parameters:
session (InferenceSession)
Return type:
model
property model: Model
Exposes the language model for graph capture/replay.
Only the language model is captured, since vision runs during prefill only.
prepare_initial_token_inputs()
prepare_initial_token_inputs(replica_batches, kv_cache_inputs=None, return_n_logits=1)
Prepare inputs for the first execution pass.
Parameters:
Return type:
prepare_next_token_inputs()
prepare_next_token_inputs(next_tokens, prev_model_inputs)
Prepares the secondary inputs to be passed to execute().
While prepare_initial_token_inputs manages the initial inputs, this function updates those inputs for each step of a multi-step execution pattern (see the sketch below).
Parameters:
- next_tokens (Buffer)
- prev_model_inputs (ModelInputs)
Return type:
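A schematic decode loop showing how the two preparation hooks work together (placeholder variables; sampling of next_tokens from the model outputs is elided behind a hypothetical sample helper):

```python
# Prefill: build the first inputs (runs the vision encoder if images are
# present in the batch), then execute once.
model_inputs = model.prepare_initial_token_inputs(
    replica_batches, kv_cache_inputs=kv_inputs, return_n_logits=1
)
outputs = model.execute(model_inputs)

# Decode: each later step only patches the previous inputs with the newly
# sampled tokens instead of rebuilding everything from the batch.
for _ in range(num_decode_steps):
    next_tokens = sample(outputs)  # placeholder sampling step
    model_inputs = model.prepare_next_token_inputs(next_tokens, model_inputs)
    outputs = model.execute(model_inputs)
```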
release()
release(request_id)
Releases the vision encoder cache for a completed request.
Parameters:
request_id (RequestID)
Return type:
None
vision_model
vision_model: Model
The compiled and initialized MAX Engine vision model ready for inference.
Gemma4ForConditionalGenerationConfig
class max.pipelines.architectures.gemma4.Gemma4ForConditionalGenerationConfig(*, devices, dtype, kv_params, image_token_index, video_token_index=262144, text_config, vision_config, tie_word_embeddings=False)
Bases: ArchConfigWithKVAndVisionCache
Base configuration for Gemma 4 multimodal models.
This is the top-level config that composes text and vision sub-configs. Model-specific parameters live in the respective sub-configs.
Parameters:
- devices (list[DeviceRef])
- dtype (DType)
- kv_params (MultiKVCacheParams)
- image_token_index (int)
- video_token_index (int)
- text_config (Gemma4TextConfig)
- vision_config (Gemma4VisionConfig)
- tie_word_embeddings (bool)
calculate_max_seq_len()
static calculate_max_seq_len(pipeline_config, huggingface_config)
Calculates the maximum sequence length for the Gemma 4 model.
Parameters:
- pipeline_config (PipelineConfig)
- huggingface_config (AutoConfig)
Return type:
construct_kv_params()
static construct_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)
Constructs KV cache parameters from the top-level HuggingFace config.
Parameters:
- huggingface_config (AutoConfig) – The top-level HuggingFace config (with text_config).
- pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
- devices (list[DeviceRef]) – Target devices for the model.
- kv_cache_config (KVCacheConfig) – KV cache configuration settings.
- cache_dtype (DType) – Data type for the KV cache.
Returns:
Configured KV cache parameters.
Return type:
devices
Devices to run the model with.
dtype
dtype: DType
DType of the model weights and input.
estimate_vision_cache_entry_bytes()
static estimate_vision_cache_entry_bytes(huggingface_config)
Estimate per-entry bytes for the vision encoder cache.
Worst-case tokens per image is position_embedding_size / pooling_kernel_size², stored at the text hidden size in bfloat16 (see the worked example below).
Parameters:
huggingface_config (AutoConfig)
Return type:
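A worked example of that sizing rule with made-up config values (the real numbers come from the HuggingFace config):

```python
# Hypothetical configuration values, for illustration only.
position_embedding_size = 4096  # vision position embedding table size
pooling_kernel_size = 2         # spatial pooling kernel
text_hidden_size = 5376         # hidden size of the text backbone
bytes_per_element = 2           # bfloat16

tokens_per_image = position_embedding_size // pooling_kernel_size**2  # 1024
entry_bytes = tokens_per_image * text_hidden_size * bytes_per_element
print(entry_bytes)  # 11010048 bytes, i.e. 10.5 MiB per cached image
```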
finalize()
finalize(huggingface_config, state_dict, return_logits)
Finalize with state_dict-dependent fields.
Parses the quantization config from the weights and finalizes the text sub-config.
Parameters:
- huggingface_config (AutoConfig) – HuggingFace model configuration.
- state_dict (dict[str, WeightData]) – Model weights dictionary.
- return_logits (ReturnLogits) – Return logits configuration.
Return type:
None
get_kv_params()
get_kv_params()
Returns the KV cache parameters.
Return type:
get_max_seq_len()
get_max_seq_len()
Returns the maximum sequence length from the embedded text config.
Return type:
image_token_index
image_token_index: int
The image token index to encode the image prompt.
initialize()
classmethod initialize(pipeline_config, model_config=None)
Initializes from the pipeline configuration.
Parameters:
- pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
- model_config (MAXModelConfig | None) – Optional model config override.
Returns:
An initialized config instance.
Return type:
initialize_from_config()
classmethod initialize_from_config(pipeline_config, huggingface_config)
Initializes from pipeline and HuggingFace configs.
Fields that depend on the state_dict should be set via finalize(); see the sketch below.
Parameters:
- pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
- huggingface_config (AutoConfig) – Top-level HuggingFace model configuration.
Returns:
A config instance ready for finalization.
Return type:
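The two-phase pattern (initialize from the configs first, then finalize once the checkpoint is loaded) looks roughly like this (placeholder variables; not the exact pipeline code):

```python
# pipeline_config, huggingface_config, state_dict, and return_logits are
# placeholders assumed to be produced by the surrounding pipeline.
config = Gemma4ForConditionalGenerationConfig.initialize_from_config(
    pipeline_config, huggingface_config
)

# state_dict-dependent fields (e.g. quantization parsed from the weights)
# are only filled in after the checkpoint has been loaded.
config.finalize(huggingface_config, state_dict, return_logits)
```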
kv_params
kv_params: MultiKVCacheParams
KV cache parameters.
text_config
text_config: Gemma4TextConfig
The config object of the text backbone.
tie_word_embeddings
tie_word_embeddings: bool = False
Whether to tie weight embeddings. When true, the output linear layer uses the same weight as the embedding layer.
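Conceptually, tying means the output projection reuses the embedding matrix rather than a separate weight; a minimal NumPy illustration of the idea (not the MAX graph code):

```python
import numpy as np

vocab_size, hidden_size = 8, 4                    # toy sizes
embedding = np.random.randn(vocab_size, hidden_size)

hidden_states = np.random.randn(2, hidden_size)   # two token positions
logits = hidden_states @ embedding.T              # tied: reuse the embedding weight
assert logits.shape == (2, vocab_size)
```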
video_token_index
video_token_index: int = 262144
The video token index to encode the video prompt.
vision_config
vision_config: Gemma4VisionConfig
The config object of the vision encoder.
Gemma4TextConfig
class max.pipelines.architectures.gemma4.Gemma4TextConfig(*, vocab_size, hidden_size, intermediate_size, num_hidden_layers, num_attention_heads, num_key_value_heads, head_dim, hidden_activation, max_position_embeddings, rms_norm_eps, rope_theta, attention_bias, query_pre_attn_scalar=None, sliding_window, final_logit_softcapping, attn_logit_softcapping=None, rope_scaling, rope_local_base_freq, sliding_window_pattern, dtype, devices, interleaved_rope_weights, return_logits=ReturnLogits.LAST_TOKEN, kv_params, tie_word_embeddings=False, quant_config=None, vocab_size_per_layer_input=262144, hidden_size_per_layer_input=0, num_global_key_value_heads=4, global_head_dim=512, attention_k_eq_v=True, num_kv_shared_layers=0, enable_moe_block=False, use_double_wide_mlp=False, num_experts=0, top_k_experts=0, moe_intermediate_size=0, global_rope_scaling=None, global_rope_theta=1000000.0, sliding_window_rope_theta=10000.0, layer_types, max_seq_len)
Bases: Gemma3Config
Text configuration for Gemma 4 models.
Defaults are from gemma-4-31b-it.
The following parameters are ignored:
- routed_layer_pattern
- stream_and_decode_in_f32
Parameters:
- vocab_size (int)
- hidden_size (int)
- intermediate_size (int)
- num_hidden_layers (int)
- num_attention_heads (int)
- num_key_value_heads (int)
- head_dim (int)
- hidden_activation (str)
- max_position_embeddings (int)
- rms_norm_eps (float)
- rope_theta (float)
- attention_bias (bool)
- query_pre_attn_scalar (float | None)
- sliding_window (int)
- final_logit_softcapping (float | None)
- attn_logit_softcapping (int | None)
- rope_scaling (LinearScalingParams | None)
- rope_local_base_freq (float)
- sliding_window_pattern (int)
- dtype (DType)
- devices (list[DeviceRef])
- interleaved_rope_weights (bool)
- return_logits (ReturnLogits)
- kv_params (KVCacheParams)
- tie_word_embeddings (bool)
- quant_config (QuantConfig | None)
- vocab_size_per_layer_input (int)
- hidden_size_per_layer_input (int)
- num_global_key_value_heads (int)
- global_head_dim (int)
- attention_k_eq_v (bool)
- num_kv_shared_layers (int)
- enable_moe_block (bool)
- use_double_wide_mlp (bool)
- num_experts (int)
- top_k_experts (int)
- moe_intermediate_size (int)
- global_rope_scaling (ProportionalScalingParams | None)
- global_rope_theta (float)
- sliding_window_rope_theta (float)
- layer_types (list[str])
- max_seq_len (int)
attention_k_eq_v
attention_k_eq_v: bool = True
Whether the key and value projections are the same.
When true, the checkpoint will not contain v_proj and v_norm weights (see the illustration below).
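In effect, when this flag is set the value projection reuses the key projection's weights; a toy NumPy illustration of the idea (not the MAX graph code):

```python
import numpy as np

hidden_size, kv_dim = 6, 4
w_k = np.random.randn(hidden_size, kv_dim)  # only k_proj exists in the checkpoint

x = np.random.randn(3, hidden_size)         # three token positions
k = x @ w_k
v = k                                       # attention_k_eq_v=True: values come from keys
```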
attn_logit_softcapping
Scaling factor when applying tanh softcapping on the attention scores.
calculate_max_seq_len()
static calculate_max_seq_len(pipeline_config, huggingface_config, model_config=None)
Calculates the maximum sequence length for the model.
Uses the max_length from the max.pipelines.config.PipelineConfig if provided; otherwise falls back to the max_position_embeddings from the HuggingFace configuration's text config (see the sketch below).
Parameters:
- pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
- huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).
- model_config (MAXModelConfig | None)
Returns:
The calculated maximum sequence length.
Return type:
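The resolution order described above amounts to a simple fallback; a hedged sketch (attribute paths follow the description and the --max-length note under get_max_seq_len, not necessarily the exact implementation):

```python
def resolve_max_seq_len(pipeline_config, huggingface_config):
    # Prefer an explicit --max-length from the pipeline config, if set.
    max_length = pipeline_config.model.max_length
    if max_length is not None:
        return max_length
    # Otherwise fall back to the text config's positional limit.
    return huggingface_config.text_config.max_position_embeddings
```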
enable_moe_block
enable_moe_block: bool = False
Whether the model uses mixture-of-experts (MoE) blocks.
get_max_seq_len()
get_max_seq_len()
Returns the default maximum sequence length for the model.
Subclasses should determine whether this value can be overridden by setting the --max-length (pipeline_config.model.max_length) flag.
Return type:
global_head_dim
global_head_dim: int = 512
Head dimension used in full attention layers.
global_rope_scaling
global_rope_scaling: ProportionalScalingParams | None = None
Scaling configuration for the RoPE embeddings used in global attention.
global_rope_theta
global_rope_theta: float = 1000000.0
RoPE theta for the rotary embeddings used in global attention.
hidden_size_per_layer_input
hidden_size_per_layer_input: int = 0
Hidden size output of the per-layer input embedding. When this is 0, the per-layer input embedding is not used.
initialize_from_config()
classmethod initialize_from_config(pipeline_config, huggingface_config)
Initialize Gemma4TextConfig from pipeline and HuggingFace configs.
Parameters:
- pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
- huggingface_config (AutoConfig) – HuggingFace text model configuration.
Returns:
An initialized Gemma4TextConfig instance.
Return type:
layer_types
max_seq_len
max_seq_len: int
The actual maximum sequence length, as determined by calculate_max_seq_len().
moe_intermediate_size
moe_intermediate_size: int = 0
Hidden dimension of each MoE expert's feed-forward block.
num_experts
num_experts: int = 0
Total number of MoE experts.
num_global_key_value_heads
num_global_key_value_heads: int = 4
Number of key-value heads used in full attention layers.
num_kv_shared_layers
num_kv_shared_layers: int = 0
The number of layers that share the KV cache; an optimization used in smaller models.
query_pre_attn_scalar
Scaling factor used on the attention scores.
rope_scaling
property rope_scaling: ProportionalScalingParams | None
rope_theta
property rope_theta: float
sliding_window_rope_theta
sliding_window_rope_theta: float = 10000.0
RoPE theta for the rotary embeddings used in sliding-window attention.
top_k_experts
top_k_experts: int = 0
Number of experts selected per token by the router.
use_double_wide_mlp
use_double_wide_mlp: bool = False
Whether the model uses a double-wide MLP.
vocab_size_per_layer_input
vocab_size_per_layer_input: int = 262144
Vocab size used in the per-layer input embedding (used in smaller architectures).
Gemma4VisionConfig
class max.pipelines.architectures.gemma4.Gemma4VisionConfig(hidden_size, intermediate_size, num_hidden_layers, num_attention_heads, num_key_value_heads, head_dim, hidden_activation, rms_norm_eps, max_position_embeddings, patch_size, position_embedding_size, pooling_kernel_size, standardize=False, attention_bias=False, attention_dropout=0.0, use_bidirectional_attention='vision', layer_types=None, use_clipped_linears=False, rope_theta=100.0)
Bases: object
Vision-specific configuration for Gemma 4 models.
Parameters:
- hidden_size (int)
- intermediate_size (int)
- num_hidden_layers (int)
- num_attention_heads (int)
- num_key_value_heads (int)
- head_dim (int)
- hidden_activation (str)
- rms_norm_eps (float)
- max_position_embeddings (int)
- patch_size (int)
- position_embedding_size (int)
- pooling_kernel_size (int)
- standardize (bool)
- attention_bias (bool)
- attention_dropout (float)
- use_bidirectional_attention (str | None)
- layer_types (list[str] | None)
- use_clipped_linears (bool)
- rope_theta (float)
attention_bias
attention_bias: bool = False
Whether to use bias in attention projection layers.
attention_dropout
attention_dropout: float = 0.0
The dropout ratio for the attention probabilities.
head_dim
head_dim: int
Dimension of each attention head.
hidden_activation
hidden_activation: str
The non-linear activation function in the encoder.
hidden_size
hidden_size: int
Dimensionality of the encoder layers.
initialize_from_config()
classmethod initialize_from_config(hf_vision_config)
Initialize Gemma4VisionConfig from a HuggingFace vision config.
Parameters:
hf_vision_config (AutoConfig) – The HuggingFace vision configuration object.
Returns:
An initialized Gemma4VisionConfig instance.
Return type:
intermediate_size
intermediate_size: int
Dimension of the MLP representations.
layer_types
Per-layer attention type specification (e.g. "full_attention").
max_position_embeddings
max_position_embeddings: int
The maximum sequence length supported by position embeddings.
num_attention_heads
num_attention_heads: int
Number of attention heads for each attention layer.
num_hidden_layers
num_hidden_layers: int
Number of hidden layers in the vision Transformer encoder.
num_key_value_heads
num_key_value_heads: int
Number of key-value heads for grouped-query attention.
patch_size
patch_size: int
The size (resolution) of each patch.
pooling_kernel_size
pooling_kernel_size: int
Kernel size for spatial pooling.
position_embedding_size
position_embedding_size: int
Size of the position embedding table.
rms_norm_eps
rms_norm_eps: float
The epsilon used by the RMS normalization layers.
rope_theta
rope_theta: float = 100.0
standardize
standardize: bool = False
Whether to standardize the image features.
use_bidirectional_attention
Controls the scope of bidirectional attention: "all", "vision", or None.
use_clipped_linears
use_clipped_linears: bool = False
Whether to use clipped linear layers.