Python module

max.pipelines.architectures.gemma4

Gemma 4 vision-language architecture for multimodal text generation.

Gemma3MultiModalModelInputs​

class max.pipelines.architectures.gemma4.Gemma3MultiModalModelInputs(tokens, input_row_offsets, signal_buffers, return_n_logits, images=None, video=None, combined_embeds=None, combined_indices=None, *, kv_cache_inputs=None, lora_ids=None, lora_ranks=None, hidden_states=None)

source

Bases: ModelInputs

A class representing inputs for the Gemma3 multimodal model.

This class encapsulates the input tensors required for the Gemma3 multimodal model, covering both text and vision processing.
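
For illustration, a hedged sketch of constructing these inputs for a text-only batch. The token values, signal_buffers, and the return_n_logits buffer are placeholders; in practice the pipeline's prepare_initial_token_inputs() builds these inputs for you:

import numpy as np

from max.pipelines.architectures.gemma4 import Gemma3MultiModalModelInputs

# Ragged batch of two sequences (lengths 3 and 2); token values are hypothetical.
tokens = np.array([2, 9054, 881, 2, 1062], dtype=np.int64)
input_row_offsets = np.array([0, 3, 5], dtype=np.uint32)

inputs = Gemma3MultiModalModelInputs(
    tokens=tokens,
    input_row_offsets=input_row_offsets,
    signal_buffers=signal_buffers,        # per-device Buffers supplied by the pipeline
    return_n_logits=return_n_logits_buf,  # placeholder Buffer holding the logit count
)                                         # images/video stay None for text-only requests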

Parameters:

buffers​

property buffers: tuple[Buffer, ...]

source

Returns positional Buffer inputs for the language model ABI.

combined_embeds​

combined_embeds: list[Buffer] | None = None

source

combined_indices​

combined_indices: list[Buffer] | None = None

source

images​

images: ImageInputs | None = None

source

input_row_offsets​

input_row_offsets: ndarray[tuple[Any, ...], dtype[integer[Any]]] | list[Buffer]

source

return_n_logits​

return_n_logits: Buffer

source

signal_buffers​

signal_buffers: list[Buffer]

source

tokens​

tokens: ndarray[tuple[Any, ...], dtype[integer[Any]]] | Buffer

source

video​

video: VideoInputs | None = None

source

Gemma3_MultiModalModel​

class max.pipelines.architectures.gemma4.Gemma3_MultiModalModel(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN)

source

Bases: AlwaysSignalBuffersMixin, PipelineModelWithKVCache[Gemma4Context]

Gemma 3 multimodal pipeline model for text generation.

This class integrates the Gemma 3 multimodal architecture with the MAX pipeline infrastructure, handling model loading, KV cache management, and input preparation for inference.

Parameters:

calculate_max_seq_len()​

classmethod calculate_max_seq_len(pipeline_config, huggingface_config)

source

Calculates the maximum sequence length for the Gemma 3 multimodal model.

Parameters:

Return type:

int

estimate_activation_memory()​

classmethod estimate_activation_memory(pipeline_config, huggingface_config)

source

Estimates the activation memory required for model execution.

This accounts for temporary memory buffers used during model execution, such as intermediate activations and working buffers.

The default implementation returns 0 for backward compatibility. Models with significant activation memory requirements should override this method to provide accurate estimates.

Parameters:

  • pipeline_config (PipelineConfig) – Pipeline configuration
  • huggingface_config (AutoConfig) – Hugging Face model configuration

Returns:

Estimated activation memory in bytes

Return type:

int
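
Since the default returns 0, a subclass with large transient buffers would override this; a minimal sketch, where the sizing heuristic is purely illustrative and not what this model actually uses:

class MyMultiModalModel(Gemma3_MultiModalModel):
    @classmethod
    def estimate_activation_memory(cls, pipeline_config, huggingface_config):
        # Illustrative heuristic: a handful of transient
        # [max_seq_len, hidden_size] buffers held in bfloat16 (2 bytes).
        max_len = cls.calculate_max_seq_len(pipeline_config, huggingface_config)
        hidden = huggingface_config.text_config.hidden_size  # assumes a text_config sub-config
        num_buffers = 4  # hypothetical working-buffer count
        return num_buffers * max_len * hidden * 2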

execute()​

execute(model_inputs)

source

Execute the vision model (if needed), then the language model.

Parameters:

model_inputs (ModelInputs)

Return type:

ModelOutputs

get_kv_params()​

classmethod get_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)

source

Gets the parameters required to configure the KV cache for the Gemma 3 multimodal model.

Parameters:

Return type:

MultiKVCacheParams

language_model​

language_model: Model

source

The compiled and initialized MAX Engine model ready for inference.

load_model()​

load_model(session)

source

Loads the compiled Gemma3 multimodal models into the MAX Engine session.

Returns:

A tuple of (vision_model, language_model).

Parameters:

session (InferenceSession)

Return type:

tuple[Model, Model]
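
For example (sketch), with an existing InferenceSession:

vision_model, language_model = pipeline_model.load_model(session)

The same two handles are afterwards available as the vision_model and language_model attributes.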

model​

property model: Model

source

Expose language model for graph capture/replay.

Only the language model is captured since vision runs during prefill only.

prepare_initial_token_inputs()​

prepare_initial_token_inputs(replica_batches, kv_cache_inputs=None, return_n_logits=1)

source

Prepare inputs for the first execution pass.

Parameters:

Return type:

ModelInputs

prepare_next_token_inputs()​

prepare_next_token_inputs(next_tokens, prev_model_inputs)

source

Prepares the secondary inputs to be passed to execute().

While prepare_initial_token_inputs manages the initial inputs, this function updates the inputs for each step in a multi-step execution pattern.

Parameters:

Return type:

ModelInputs
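
Taken together, the two preparation methods support a prefill-then-decode loop along these lines (a sketch; sample_next_tokens and num_steps are hypothetical stand-ins for the pipeline's sampling logic):

# Prefill: build the first inputs and run the vision (if needed) and language models.
model_inputs = model.prepare_initial_token_inputs(replica_batches, kv_cache_inputs=kv_inputs)
outputs = model.execute(model_inputs)

# Decode: update the inputs once per generated token.
for _ in range(num_steps):
    next_tokens = sample_next_tokens(outputs)  # hypothetical sampler
    model_inputs = model.prepare_next_token_inputs(next_tokens, model_inputs)
    outputs = model.execute(model_inputs)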

release()​

release(request_id)

source

Release vision encoder cache for a completed request.

Parameters:

request_id (RequestID)

Return type:

None

vision_model​

vision_model: Model

source

The compiled and initialized MAX Engine vision model ready for inference.

Gemma4ForConditionalGenerationConfig​

class max.pipelines.architectures.gemma4.Gemma4ForConditionalGenerationConfig(*, devices, dtype, kv_params, image_token_index, video_token_index=262144, text_config, vision_config, tie_word_embeddings=False)

source

Bases: ArchConfigWithKVAndVisionCache

Base configuration for Gemma 4 multimodal models.

This is the top-level config that composes text and vision sub-configs. Model-specific parameters live in the respective sub-configs.

Parameters:

calculate_max_seq_len()​

static calculate_max_seq_len(pipeline_config, huggingface_config)

source

Calculates the maximum sequence length for the Gemma 4 model.

Parameters:

Return type:

int

construct_kv_params()​

static construct_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)

source

Constructs KV cache parameters from the top-level HuggingFace config.

Parameters:

  • huggingface_config (AutoConfig) – The top-level HuggingFace config (with text_config).
  • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
  • devices (list[DeviceRef]) – Target devices for the model.
  • kv_cache_config (KVCacheConfig) – KV cache configuration settings.
  • cache_dtype (DType) – Data type for the KV cache.

Returns:

Configured KV cache parameters.

Return type:

MultiKVCacheParams
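
A sketch of a call site, assuming the argument objects were already produced by the usual MAX config machinery:

kv_params = Gemma4ForConditionalGenerationConfig.construct_kv_params(
    huggingface_config=hf_config,    # top-level config with text_config
    pipeline_config=pipeline_config,
    devices=devices,                 # list[DeviceRef]
    kv_cache_config=kv_cache_config,
    cache_dtype=cache_dtype,         # DType of the KV cache
)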

devices​

devices: list[DeviceRef]

source

Devices to run the model with.

dtype​

dtype: DType

source

DType of the model weights and input.

estimate_vision_cache_entry_bytes()​

static estimate_vision_cache_entry_bytes(huggingface_config)

source

Estimate per-entry bytes for the vision encoder cache.

The worst-case token count per image is position_embedding_size / pooling_kernel_size², stored at the text hidden size in bfloat16.

Parameters:

huggingface_config (AutoConfig)

Return type:

int
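
A worked example, reading the formula as position_embedding_size divided by pooling_kernel_size squared; all values here are hypothetical:

position_embedding_size = 4096  # hypothetical position embedding table size
pooling_kernel_size = 4         # hypothetical pooling kernel
text_hidden_size = 5376         # hypothetical text hidden size

tokens_per_image = position_embedding_size // pooling_kernel_size**2  # 256
entry_bytes = tokens_per_image * text_hidden_size * 2                 # bfloat16 = 2 bytes -> 2,752,512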

finalize()​

finalize(huggingface_config, state_dict, return_logits)

source

Finalize with state_dict-dependent fields.

Parses quantization config from the weights and finalizes the text sub-config.

Parameters:

  • huggingface_config (AutoConfig) – HuggingFace model configuration.
  • state_dict (dict[str, WeightData]) – Model weights dictionary.
  • return_logits (ReturnLogits) – Return logits configuration.

Return type:

None

get_kv_params()​

get_kv_params()

source

Returns the KV cache parameters.

Return type:

MultiKVCacheParams

get_max_seq_len()​

get_max_seq_len()

source

Returns the maximum sequence length from the embedded text config.

Return type:

int

image_token_index​

image_token_index: int

source

The token index used to encode the image prompt.

initialize()​

classmethod initialize(pipeline_config, model_config=None)

source

Initializes from pipeline configuration.

Parameters:

  • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
  • model_config (MAXModelConfig | None) – Optional model config override.

Returns:

An initialized config instance.

Return type:

Self

initialize_from_config()​

classmethod initialize_from_config(pipeline_config, huggingface_config)

source

Initializes from pipeline and HuggingFace configs.

Fields that depend on the state_dict should be set via finalize().

Parameters:

  • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
  • huggingface_config (AutoConfig) – Top-level HuggingFace model configuration.

Returns:

A config instance ready for finalization.

Return type:

Self
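
Together with finalize(), the documented two-phase construction looks roughly like this (a sketch; the surrounding config, weight, and return_logits objects are assumed to exist):

config = Gemma4ForConditionalGenerationConfig.initialize_from_config(
    pipeline_config, huggingface_config
)
# ... load checkpoint weights into state_dict ...
config.finalize(huggingface_config, state_dict, return_logits)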

kv_params​

kv_params: MultiKVCacheParams

source

KV cache parameters.

text_config​

text_config: Gemma4TextConfig

source

The config object of the text backbone.

tie_word_embeddings​

tie_word_embeddings: bool = False

source

Whether to tie the word embeddings. When true, the output linear layer uses the same weight as the embedding layer.
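
A minimal, generic sketch of what tying means, using toy numpy arrays rather than the MAX graph API:

import numpy as np

vocab_size, hidden_size = 1000, 64  # toy dimensions
tie_word_embeddings = True          # mirrors the config flag

embedding_weight = np.random.randn(vocab_size, hidden_size).astype(np.float32)
lm_head_weight = (
    embedding_weight  # shared storage when tied
    if tie_word_embeddings
    else np.random.randn(vocab_size, hidden_size).astype(np.float32)
)

hidden_states = np.random.randn(1, hidden_size).astype(np.float32)
logits = hidden_states @ lm_head_weight.T  # output projection reuses the embedding matrix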

video_token_index​

video_token_index: int = 262144

source

The token index used to encode the video prompt.

vision_config​

vision_config: Gemma4VisionConfig

source

The config object of the vision encoder.

Gemma4TextConfig​

class max.pipelines.architectures.gemma4.Gemma4TextConfig(*, vocab_size, hidden_size, intermediate_size, num_hidden_layers, num_attention_heads, num_key_value_heads, head_dim, hidden_activation, max_position_embeddings, rms_norm_eps, rope_theta, attention_bias, query_pre_attn_scalar=None, sliding_window, final_logit_softcapping, attn_logit_softcapping=None, rope_scaling, rope_local_base_freq, sliding_window_pattern, dtype, devices, interleaved_rope_weights, return_logits=ReturnLogits.LAST_TOKEN, kv_params, tie_word_embeddings=False, quant_config=None, vocab_size_per_layer_input=262144, hidden_size_per_layer_input=0, num_global_key_value_heads=4, global_head_dim=512, attention_k_eq_v=True, num_kv_shared_layers=0, enable_moe_block=False, use_double_wide_mlp=False, num_experts=0, top_k_experts=0, moe_intermediate_size=0, global_rope_scaling=None, global_rope_theta=1000000.0, sliding_window_rope_theta=10000.0, layer_types, max_seq_len)

source

Bases: Gemma3Config

Text configuration for Gemma 4 models.

Defaults are from gemma-4-31b-it.

The following parameters are ignored:
  • routed_layer_pattern
  • stream_and_decode_in_f32

Parameters:

  • vocab_size (int)
  • hidden_size (int)
  • intermediate_size (int)
  • num_hidden_layers (int)
  • num_attention_heads (int)
  • num_key_value_heads (int)
  • head_dim (int)
  • hidden_activation (str)
  • max_position_embeddings (int)
  • rms_norm_eps (float)
  • rope_theta (float)
  • attention_bias (bool)
  • query_pre_attn_scalar (float | None)
  • sliding_window (int)
  • final_logit_softcapping (float | None)
  • attn_logit_softcapping (int | None)
  • rope_scaling (LinearScalingParams | None)
  • rope_local_base_freq (float)
  • sliding_window_pattern (int)
  • dtype (DType)
  • devices (list[DeviceRef])
  • interleaved_rope_weights (bool)
  • return_logits (ReturnLogits)
  • kv_params (KVCacheParams)
  • tie_word_embeddings (bool)
  • quant_config (QuantConfig | None)
  • vocab_size_per_layer_input (int)
  • hidden_size_per_layer_input (int)
  • num_global_key_value_heads (int)
  • global_head_dim (int)
  • attention_k_eq_v (bool)
  • num_kv_shared_layers (int)
  • enable_moe_block (bool)
  • use_double_wide_mlp (bool)
  • num_experts (int)
  • top_k_experts (int)
  • moe_intermediate_size (int)
  • global_rope_scaling (ProportionalScalingParams | None)
  • global_rope_theta (float)
  • sliding_window_rope_theta (float)
  • layer_types (list[str])
  • max_seq_len (int)

attention_k_eq_v

attention_k_eq_v: bool = True

source

Whether the key and value projections are the same.

When true, the checkpoint will not contain v_proj and v_norm weights.
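
A hedged sketch of the implication (toy numpy shapes; the real layers also apply norms and per-head reshapes that are elided here):

import numpy as np

attention_k_eq_v = True
seq_len, hidden, head_dim = 4, 8, 8  # toy sizes
x = np.random.randn(seq_len, hidden)
w_k = np.random.randn(hidden, head_dim)

k = x @ w_k
# The value path reuses the key projection, so no v_proj weights are needed.
v = k if attention_k_eq_v else x @ np.random.randn(hidden, head_dim)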

attn_logit_softcapping

attn_logit_softcapping: int | None = None

source

Scaling factor when applying tanh softcapping on the attention scores.

calculate_max_seq_len()

static calculate_max_seq_len(pipeline_config, huggingface_config, model_config=None)

source

Calculates the maximum sequence length for the model.

Uses the max_length from the max.pipelines.config.PipelineConfig if provided, otherwise falls back to the max_position_embeddings from the HuggingFace configuration’s text config.

Parameters:

  • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
  • huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).
  • model_config (MAXModelConfig | None)

Returns:

The calculated maximum sequence length.

Return type:

int
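
The documented fallback amounts to roughly the following (a sketch of the selection logic, not the exact implementation):

max_length = pipeline_config.model.max_length  # set by --max-length
if max_length is not None:
    max_seq_len = max_length
else:
    max_seq_len = huggingface_config.text_config.max_position_embeddings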

enable_moe_block

enable_moe_block: bool = False

source

Whether the model uses MoE blocks.

get_max_seq_len()

get_max_seq_len()

source

Returns the default maximum sequence length for the model.

Subclasses should determine whether this value can be overridden by setting the --max-length (pipeline_config.model.max_length) flag.

Return type:

int

global_head_dim

global_head_dim: int = 512

source

Head dimension used in full attention layers.

global_rope_scaling

global_rope_scaling: ProportionalScalingParams | None = None

source

Scaling configuration for the RoPE embeddings used in global attention.

global_rope_theta

global_rope_theta: float = 1000000.0

source

RoPE theta for the embeddings used in global attention.

hidden_size_per_layer_input

hidden_size_per_layer_input: int = 0

source

Hidden size output of the per-layer input embedding. When this is 0, the per-layer input embedding is not used.

initialize_from_config()

classmethod initialize_from_config(pipeline_config, huggingface_config)

source

Initialize Gemma4TextConfig from pipeline and HuggingFace configs.

Parameters:

  • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
  • huggingface_config (AutoConfig) – HuggingFace text model configuration.

Returns:

An initialized Gemma4TextConfig instance.

Return type:

Self

layer_types

layer_types: list[str]

source

max_seq_len

max_seq_len: int

source

The actual maximum sequence length, as determined by calculate_max_seq_len().

moe_intermediate_size

moe_intermediate_size: int = 0

source

Hidden dimension of each MoE expert’s feed-forward block.

num_experts

num_experts: int = 0

source

Total number of MoE experts.

num_global_key_value_heads

num_global_key_value_heads: int = 4

source

Number of key-value heads used in full attention layers.

num_kv_shared_layers

num_kv_shared_layers: int = 0

source

An optimization used in smaller models to share the KV cache across layers.

query_pre_attn_scalar

query_pre_attn_scalar: float | None = None

source

Scaling factor used on the attention scores.

rope_scaling

property rope_scaling: ProportionalScalingParams | None

source

rope_theta

property rope_theta: float

source

sliding_window_rope_theta

sliding_window_rope_theta: float = 10000.0

source

RoPE theta for the embeddings used in sliding window attention.

top_k_experts

top_k_experts: int = 0

source

Number of experts selected per token by the router.

use_double_wide_mlp

use_double_wide_mlp: bool = False

source

Whether the model uses a double-wide MLP.

vocab_size_per_layer_input

vocab_size_per_layer_input: int = 262144

source

Vocab size used in the per-layer input embedding (used in smaller architectures).

Gemma4VisionConfig

class max.pipelines.architectures.gemma4.Gemma4VisionConfig(hidden_size, intermediate_size, num_hidden_layers, num_attention_heads, num_key_value_heads, head_dim, hidden_activation, rms_norm_eps, max_position_embeddings, patch_size, position_embedding_size, pooling_kernel_size, standardize=False, attention_bias=False, attention_dropout=0.0, use_bidirectional_attention='vision', layer_types=None, use_clipped_linears=False, rope_theta=100.0)

source

Bases: object

Vision-specific configuration for Gemma 4 models.

Parameters:

  • hidden_size (int)
  • intermediate_size (int)
  • num_hidden_layers (int)
  • num_attention_heads (int)
  • num_key_value_heads (int)
  • head_dim (int)
  • hidden_activation (str)
  • rms_norm_eps (float)
  • max_position_embeddings (int)
  • patch_size (int)
  • position_embedding_size (int)
  • pooling_kernel_size (int)
  • standardize (bool)
  • attention_bias (bool)
  • attention_dropout (float)
  • use_bidirectional_attention (str | None)
  • layer_types (list[str] | None)
  • use_clipped_linears (bool)
  • rope_theta (float)

attention_bias

attention_bias: bool = False

source

Whether to use bias in attention projection layers.

attention_dropout

attention_dropout: float = 0.0

source

The dropout ratio for the attention probabilities.

head_dim

head_dim: int

source

Dimension of each attention head.

hidden_activation

hidden_activation: str

source

The non-linear activation function in the encoder.

hidden_size

hidden_size: int

source

Dimensionality of the encoder layers.

initialize_from_config()

classmethod initialize_from_config(hf_vision_config)

source

Initialize Gemma4VisionConfig from a HuggingFace vision config.

Parameters:

hf_vision_config (AutoConfig) – The HuggingFace vision configuration object.

Returns:

An initialized Gemma4VisionConfig instance.

Return type:

Gemma4VisionConfig
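
For example, assuming the top-level HuggingFace config exposes a vision_config sub-config (as is conventional for HF multimodal models):

vision_cfg = Gemma4VisionConfig.initialize_from_config(hf_config.vision_config)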

intermediate_size

intermediate_size: int

source

Dimension of the MLP representations.

layer_types

layer_types: list[str] | None = None

source

Per-layer attention type specification (e.g. "full_attention").

max_position_embeddings

max_position_embeddings: int

source

The maximum sequence length supported by position embeddings.

num_attention_heads

num_attention_heads: int

source

Number of attention heads for each attention layer.

num_hidden_layers

num_hidden_layers: int

source

Number of hidden layers in the vision Transformer encoder.

num_key_value_heads

num_key_value_heads: int

source

Number of key-value heads for grouped-query attention.

patch_size

patch_size: int

source

The size (resolution) of each patch.

pooling_kernel_size

pooling_kernel_size: int

source

Kernel size for spatial pooling.

position_embedding_size

position_embedding_size: int

source

Size of the position embedding table.

rms_norm_eps

rms_norm_eps: float

source

The epsilon used by the RMS normalization layers.

rope_theta

rope_theta: float = 100.0

source

standardize

standardize: bool = False

source

Whether to standardize the image features.

use_bidirectional_attention

use_bidirectional_attention: str | None = 'vision'

source

Controls the bidirectional attention scope: "all", "vision", or None.
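
A hedged sketch of how such a scope flag could translate into an attention mask; the per-value semantics here are an assumption based on the names, and is_vision_token is hypothetical:

def allow_attention(i, j, scope, is_vision_token):
    # Causal by default; widen to bidirectional per the scope flag.
    if scope == "all":
        return True
    if scope == "vision" and is_vision_token[i] and is_vision_token[j]:
        return True
    return j <= i

# Example: tokens 1-2 are image tokens, so they attend to each other bidirectionally.
is_vision = [False, True, True, False]
mask = [[allow_attention(i, j, "vision", is_vision) for j in range(4)] for i in range(4)]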

use_clipped_linears

use_clipped_linears: bool = False

source

Whether to use clipped linear layers.