Python module

max.pipelines.architectures.gemma4

Gemma 4 vision-language architecture for multimodal text generation.

Gemma3MultiModalModelInputs​

class max.pipelines.architectures.gemma4.Gemma3MultiModalModelInputs(tokens, input_row_offsets, signal_buffers, return_n_logits, images=None, video=None, combined_embeds=None, combined_indices=None, *, kv_cache_inputs=None, lora_ids=None, lora_ranks=None, hidden_states=None)

source

Bases: ModelInputs

Inputs for the Gemma3 multimodal model.

This class encapsulates the input tensors that the Gemma3 multimodal model requires for both text and vision processing.

Parameters:

buffers​

property buffers: tuple[Buffer, ...]

source

Returns positional Buffer inputs for the language model ABI.

combined_embeds​

combined_embeds: list[Buffer] | None = None

source

combined_indices​

combined_indices: list[Buffer] | None = None

source

images​

images: ImageInputs | None = None

source

input_row_offsets​

input_row_offsets: ndarray[tuple[Any, ...], dtype[integer[Any]]] | list[Buffer]

source

return_n_logits​

return_n_logits: Buffer

source

signal_buffers​

signal_buffers: list[Buffer]

source

tokens​

tokens: ndarray[tuple[Any, ...], dtype[integer[Any]]] | Buffer

source

video​

video: VideoInputs | None = None

source

Gemma3_MultiModalModel​

class max.pipelines.architectures.gemma4.Gemma3_MultiModalModel(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN)

source

Bases: AlwaysSignalBuffersMixin, PipelineModelWithKVCache[Gemma4Context]

Gemma 3 multimodal pipeline model for text generation.

This class integrates the Gemma 3 multimodal architecture with the MAX pipeline infrastructure, handling model loading, KV cache management, and input preparation for inference.

Parameters:

calculate_max_seq_len()​

classmethod calculate_max_seq_len(pipeline_config, huggingface_config)

source

Calculates the maximum sequence length for the model.

Parameters:

Return type:

int

estimate_activation_memory()​

classmethod estimate_activation_memory(pipeline_config, huggingface_config)

source

Estimates the activation memory required for model execution.

This accounts for temporary memory buffers used during model execution, such as intermediate activations and working buffers.

The default implementation returns 0 for backward compatibility. Models with significant activation memory requirements should override this method to provide accurate estimates.

Parameters:

  • pipeline_config (PipelineConfig) – Pipeline configuration
  • huggingface_config (AutoConfig) – Hugging Face model configuration

Returns:

Estimated activation memory in bytes

Return type:

int

execute()​

execute(model_inputs)

source

Execute the vision model (if needed), then the language model.

Parameters:

model_inputs (ModelInputs)

Return type:

ModelOutputs

get_kv_params()​

classmethod get_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)

source

Gets the parameters required to configure the KV cache for the model.

Parameters:

Return type:

MultiKVCacheParams

language_model​

language_model: Model

source

The compiled and initialized MAX Engine model ready for inference.

load_model()​

load_model(session)

source

Loads the compiled Gemma3 MultiModal models into the MAX Engine session.

Returns:

A tuple of (vision_model, language_model).

Parameters:

session (InferenceSession)

Return type:

tuple[Model, Model]

model​

property model: Model

source

Expose language model for graph capture/replay.

Only the language model is captured since vision runs during prefill only.

prepare_initial_token_inputs()​

prepare_initial_token_inputs(replica_batches, kv_cache_inputs=None, return_n_logits=1)

source

Prepare inputs for the first execution pass.

Parameters:

Return type:

ModelInputs

prepare_next_token_inputs()​

prepare_next_token_inputs(next_tokens, prev_model_inputs)

source

Prepares the secondary inputs to be passed to execute().

While prepare_initial_token_inputs manages the initial inputs, this function updates the inputs for each subsequent step in a multi-step execution pattern.

Parameters:

Return type:

ModelInputs
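The multi-step pattern described above can be sketched with toy stand-ins. This is only an illustration of the control flow: `ToyModel`, `ToyInputs`, and `generate` are hypothetical names, and the simplified method signatures do not match the real `Gemma3_MultiModalModel` API (which takes `replica_batches`, `kv_cache_inputs`, and so on).

```python
from dataclasses import dataclass


@dataclass
class ToyInputs:
    # Hypothetical stand-in for ModelInputs: just a list of token IDs.
    tokens: list


class ToyModel:
    """Hypothetical stand-in that mimics the three-call pattern."""

    def prepare_initial_token_inputs(self, prompt):
        # First pass: the whole prompt is fed to the model.
        return ToyInputs(tokens=list(prompt))

    def execute(self, model_inputs):
        # Fake "inference": the next token is the last token plus one.
        return model_inputs.tokens[-1] + 1

    def prepare_next_token_inputs(self, next_tokens, prev_model_inputs):
        # Later passes: only the newly generated token is fed back in.
        return ToyInputs(tokens=[next_tokens])


def generate(model, prompt, num_steps):
    inputs = model.prepare_initial_token_inputs(prompt)
    generated = []
    for _ in range(num_steps):
        next_token = model.execute(inputs)
        generated.append(next_token)
        inputs = model.prepare_next_token_inputs(next_token, inputs)
    return generated


result = generate(ToyModel(), [1, 2, 3], num_steps=3)
print(result)
```

In this toy version the initial call consumes the full prompt, while every later step passes a single new token, which mirrors the split of responsibilities between the two prepare methods.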

release()​

release(request_id)

source

Release vision encoder cache for a completed request.

Parameters:

request_id (RequestID)

Return type:

None

vision_model​

vision_model: Model

source

The compiled and initialized MAX Engine vision model ready for inference.

Gemma4ForConditionalGenerationConfig​

class max.pipelines.architectures.gemma4.Gemma4ForConditionalGenerationConfig(*, devices, dtype, kv_params, image_token_index, video_token_index=262144, text_config, vision_config, tie_word_embeddings=False)

source

Bases: ArchConfigWithKVAndVisionCache

Base configuration for Gemma 4 multimodal models.

This is the top-level config that composes text and vision sub-configs. Model-specific parameters live in the respective sub-configs.

Parameters:

calculate_max_seq_len()​

static calculate_max_seq_len(pipeline_config, huggingface_config)

source

Calculates the maximum sequence length for the Gemma 4 model.

Parameters:

Return type:

int

construct_kv_params()​

static construct_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)

source

Constructs KV cache parameters from the top-level HuggingFace config.

Parameters:

  • huggingface_config (AutoConfig) – The top-level HuggingFace config (with text_config).
  • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
  • devices (list[DeviceRef]) – Target devices for the model.
  • kv_cache_config (KVCacheConfig) – KV cache configuration settings.
  • cache_dtype (DType) – Data type for the KV cache.

Returns:

Configured KV cache parameters.

Return type:

MultiKVCacheParams

devices​

devices: list[DeviceRef]

source

Devices to run the model with.

dtype​

dtype: DType

source

DType of the model weights and input.

estimate_vision_cache_entry_bytes()​

static estimate_vision_cache_entry_bytes(huggingface_config)

source

Estimate per-entry bytes for the vision encoder cache.

Worst-case tokens per image is position_embedding_size / pooling_kernel_size², stored at the text hidden size in bfloat16.

Parameters:

huggingface_config (AutoConfig)

Return type:

int
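As a rough sketch of the arithmetic described above: the worst-case token count is the position embedding size divided by the squared pooling kernel size, and each token is a vector of the text hidden size stored in bfloat16 (2 bytes per element). The function name, the use of floor division, and the example values are assumptions for illustration, not the library's implementation or real Gemma 4 settings.

```python
BFLOAT16_BYTES = 2  # bfloat16 uses 2 bytes per element


def estimate_entry_bytes(
    position_embedding_size: int,
    pooling_kernel_size: int,
    text_hidden_size: int,
) -> int:
    # Worst-case number of vision tokens per image after spatial pooling.
    # Floor division is an assumption; the real code may round differently.
    tokens = position_embedding_size // (pooling_kernel_size**2)
    # Each token is one vector of the text hidden size in bfloat16.
    return tokens * text_hidden_size * BFLOAT16_BYTES


# Illustrative values only.
entry_bytes = estimate_entry_bytes(
    position_embedding_size=2304,
    pooling_kernel_size=3,
    text_hidden_size=1024,
)
print(entry_bytes)
```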

finalize()​

finalize(huggingface_config, state_dict, return_logits)

source

Finalize with state_dict-dependent fields.

Parses quantization config from the weights and finalizes the text sub-config.

Parameters:

  • huggingface_config (AutoConfig) – HuggingFace model configuration.
  • state_dict (dict[str, WeightData]) – Model weights dictionary.
  • return_logits (ReturnLogits) – Return logits configuration.

Return type:

None

get_kv_params()​

get_kv_params()

source

Returns the KV cache parameters.

Return type:

MultiKVCacheParams

get_max_seq_len()​

get_max_seq_len()

source

Returns the maximum sequence length from the embedded text config.

Return type:

int

image_token_index​

image_token_index: int

source

The image token index to encode the image prompt.

initialize()​

classmethod initialize(pipeline_config, model_config=None)

source

Initializes from pipeline configuration.

Parameters:

  • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
  • model_config (MAXModelConfig | None) – Optional model config override.

Returns:

An initialized config instance.

Return type:

Self

initialize_from_config()​

classmethod initialize_from_config(pipeline_config, huggingface_config)

source

Initializes from pipeline and HuggingFace configs.

Fields that depend on the state_dict should be set via finalize().

Parameters:

  • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
  • huggingface_config (AutoConfig) – Top-level HuggingFace model configuration.

Returns:

A config instance ready for finalization.

Return type:

Self

kv_params​

kv_params: MultiKVCacheParams

source

KV cache parameters.

text_config​

text_config: Gemma4TextConfig

source

The config object of the text backbone.

tie_word_embeddings​

tie_word_embeddings: bool = False

source

Whether to tie weight embeddings. When true, the output linear layer uses the same weight as the embedding layer.

video_token_index​

video_token_index: int = 262144

source

The video token index to encode the video prompt.

vision_config​

vision_config: Gemma4VisionConfig

source

The config object of the vision encoder.

Gemma4TextConfig​

class max.pipelines.architectures.gemma4.Gemma4TextConfig(*, vocab_size, hidden_size, intermediate_size, num_hidden_layers, num_attention_heads, num_key_value_heads, head_dim, hidden_activation, max_position_embeddings, rms_norm_eps, rope_theta, attention_bias, query_pre_attn_scalar=None, sliding_window, final_logit_softcapping, attn_logit_softcapping=None, rope_scaling, rope_local_base_freq, sliding_window_pattern, dtype, devices, interleaved_rope_weights, return_logits=ReturnLogits.LAST_TOKEN, kv_params, tie_word_embeddings=False, quant_config=None, vocab_size_per_layer_input=262144, hidden_size_per_layer_input=0, num_global_key_value_heads=4, global_head_dim=512, attention_k_eq_v=True, num_kv_shared_layers=0, enable_moe_block=False, use_double_wide_mlp=False, num_experts=0, top_k_experts=0, moe_intermediate_size=0, global_rope_scaling=None, global_rope_theta=1000000.0, sliding_window_rope_theta=10000.0, layer_types, max_seq_len)

source

Bases: Gemma3Config

Text configuration for Gemma 4 models.

Defaults are from gemma-4-31b-it.

The following parameters are ignored:
  • routed_layer_pattern
  • stream_and_decode_in_f32

    Parameters:

    • vocab_size (int)
    • hidden_size (int)
    • intermediate_size (int)
    • num_hidden_layers (int)
    • num_attention_heads (int)
    • num_key_value_heads (int)
    • head_dim (int)
    • hidden_activation (str)
    • max_position_embeddings (int)
    • rms_norm_eps (float)
    • rope_theta (float)
    • attention_bias (bool)
    • query_pre_attn_scalar (float | None)
    • sliding_window (int)
    • final_logit_softcapping (float | None)
    • attn_logit_softcapping (int | None)
    • rope_scaling (LinearScalingParams | None)
    • rope_local_base_freq (float)
    • sliding_window_pattern (int)
    • dtype (DType)
    • devices (list[DeviceRef])
    • interleaved_rope_weights (bool)
    • return_logits (ReturnLogits)
    • kv_params (KVCacheParams)
    • tie_word_embeddings (bool)
    • quant_config (QuantConfig | None)
    • vocab_size_per_layer_input (int)
    • hidden_size_per_layer_input (int)
    • num_global_key_value_heads (int)
    • global_head_dim (int)
    • attention_k_eq_v (bool)
    • num_kv_shared_layers (int)
    • enable_moe_block (bool)
    • use_double_wide_mlp (bool)
    • num_experts (int)
    • top_k_experts (int)
    • moe_intermediate_size (int)
    • global_rope_scaling (ProportionalScalingParams | None)
    • global_rope_theta (float)
    • sliding_window_rope_theta (float)
    • layer_types (list[str])
    • max_seq_len (int)

    attention_k_eq_v​

    attention_k_eq_v: bool = True

    source

    Whether the key and value projections are the same.

    When true, the checkpoint will not contain v_proj and v_norm weights.

    attn_logit_softcapping​

    attn_logit_softcapping: int | None = None

    source

    Scaling factor when applying tanh softcapping on the attention scores.
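Tanh softcapping is commonly written as cap * tanh(score / cap), which squashes scores smoothly into the range (-cap, cap). The sketch below shows that common formulation on a single scalar; the function name is hypothetical, and the exact form used inside the MAX attention kernels is an assumption here.

```python
import math
from typing import Optional


def softcap(score: float, cap: Optional[float]) -> float:
    # With no cap configured (None), the score passes through unchanged.
    if cap is None:
        return score
    # Common tanh softcapping form: bounded by +/- cap, near-linear around 0.
    return cap * math.tanh(score / cap)


print(softcap(1.0, 50.0))     # small scores are almost unchanged
print(softcap(1000.0, 50.0))  # large scores saturate near the cap
```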

    calculate_max_seq_len()​

    static calculate_max_seq_len(pipeline_config, huggingface_config, model_config=None)

    source

    Calculates the maximum sequence length for the model.

    Uses the max_length from the max.pipelines.config.PipelineConfig if provided, otherwise falls back to the max_position_embeddings from the HuggingFace configuration’s text config.

    Parameters:

    • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
    • huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).
    • model_config (MAXModelConfig | None)

    Returns:

    The calculated maximum sequence length.

    Return type:

    int
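The fallback rule described above can be sketched as follows. The function and its plain integer arguments are hypothetical simplifications: the real method reads these values from PipelineConfig and the HuggingFace text config, and may apply additional validation that is not shown here.

```python
from typing import Optional


def resolve_max_seq_len(
    max_length: Optional[int],
    max_position_embeddings: int,
) -> int:
    # Prefer the user-provided max_length from the pipeline config.
    if max_length is not None:
        return max_length
    # Otherwise fall back to the text config's max_position_embeddings.
    return max_position_embeddings


print(resolve_max_seq_len(None, 131072))
print(resolve_max_seq_len(8192, 131072))
```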

    enable_moe_block​

    enable_moe_block: bool = False

    source

    Whether the model uses mixture-of-experts (MoE) blocks.

    get_max_seq_len()​

    get_max_seq_len()

    source

    Returns the default maximum sequence length for the model.

    Subclasses should determine whether this value can be overridden by setting the --max-length (pipeline_config.model.max_length) flag.

    Return type:

    int

    global_head_dim​

    global_head_dim: int = 512

    source

    Head dimension used in full attention layers.

    global_rope_scaling​

    global_rope_scaling: ProportionalScalingParams | None = None

    source

    Scaling configuration for the RoPE embeddings used in global attention.

    global_rope_theta​

    global_rope_theta: float = 1000000.0

    source

    Theta value for the RoPE embeddings used in global attention.

    hidden_size_per_layer_input​

    hidden_size_per_layer_input: int = 0

    source

    Hidden size output of the per-layer input embedding. When this is 0, the per-layer input embedding is not used.

    initialize_from_config()​

    classmethod initialize_from_config(pipeline_config, huggingface_config)

    source

    Initialize Gemma4TextConfig from pipeline and HuggingFace configs.

    Parameters:

    • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
    • huggingface_config (AutoConfig) – HuggingFace text model configuration.

    Returns:

    An initialized Gemma4TextConfig instance.

    Return type:

    Self

    layer_types​

    layer_types: list[str]

    source

    max_seq_len​

    max_seq_len: int

    source

    Actual maximum sequence length, as determined by calculate_max_seq_len.

    moe_intermediate_size​

    moe_intermediate_size: int = 0

    source

    Hidden dimension of each MoE expert’s feed-forward block.

    num_experts​

    num_experts: int = 0

    source

    Total number of MoE experts.

    num_global_key_value_heads​

    num_global_key_value_heads: int = 4

    source

    Number of key value heads used in full attention layers.

    num_kv_shared_layers​

    num_kv_shared_layers: int = 0

    source

    Number of layers that share the KV cache, an optimization used in smaller models.

    query_pre_attn_scalar​

    query_pre_attn_scalar: float | None = None

    source

    Scaling factor used on the attention scores.

    rope_scaling​

    property rope_scaling: ProportionalScalingParams | None

    source

    rope_theta​

    property rope_theta: float

    source

    sliding_window_rope_theta​

    sliding_window_rope_theta: float = 10000.0

    source

    Theta value for the RoPE embeddings used in sliding window attention.

    top_k_experts​

    top_k_experts: int = 0

    source

    Number of experts selected per token by the router.
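To illustrate what selecting top_k_experts out of num_experts means, here is a minimal sketch of top-k selection over one token's router scores. The function name is hypothetical, and how the actual router computes its scores or weights the selected experts is not specified in this reference.

```python
def route_top_k(router_logits, top_k):
    # Rank expert indices by router score, highest first.
    ranked = sorted(
        range(len(router_logits)),
        key=lambda i: router_logits[i],
        reverse=True,
    )
    # Keep only the top_k highest-scoring experts for this token.
    return ranked[:top_k]


# Four experts (num_experts=4), two selected per token (top_k_experts=2).
selected = route_top_k([0.1, 2.0, -1.0, 0.7], top_k=2)
print(selected)
```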

    use_double_wide_mlp​

    use_double_wide_mlp: bool = False

    source

    Whether the model uses a double-wide MLP.

    vocab_size_per_layer_input​

    vocab_size_per_layer_input: int = 262144

    source

    Vocab size used in the per-layer input embedding (used in smaller architectures).

    Gemma4VisionConfig​

    class max.pipelines.architectures.gemma4.Gemma4VisionConfig(hidden_size, intermediate_size, num_hidden_layers, num_attention_heads, num_key_value_heads, head_dim, hidden_activation, rms_norm_eps, max_position_embeddings, patch_size, position_embedding_size, pooling_kernel_size, standardize=False, attention_bias=False, attention_dropout=0.0, use_bidirectional_attention='vision', layer_types=None, use_clipped_linears=False, rope_theta=100.0)

    source

    Bases: object

    Vision-specific configuration for Gemma 4 models.

    Parameters:

    • hidden_size (int)
    • intermediate_size (int)
    • num_hidden_layers (int)
    • num_attention_heads (int)
    • num_key_value_heads (int)
    • head_dim (int)
    • hidden_activation (str)
    • rms_norm_eps (float)
    • max_position_embeddings (int)
    • patch_size (int)
    • position_embedding_size (int)
    • pooling_kernel_size (int)
    • standardize (bool)
    • attention_bias (bool)
    • attention_dropout (float)
    • use_bidirectional_attention (str | None)
    • layer_types (list[str] | None)
    • use_clipped_linears (bool)
    • rope_theta (float)

    attention_bias​

    attention_bias: bool = False

    source

    Whether to use bias in attention projection layers.

    attention_dropout​

    attention_dropout: float = 0.0

    source

    The dropout ratio for the attention probabilities.

    head_dim​

    head_dim: int

    source

    Dimension of each attention head.

    hidden_activation​

    hidden_activation: str

    source

    The non-linear activation function in the encoder.

    hidden_size​

    hidden_size: int

    source

    Dimensionality of the encoder layers.

    initialize_from_config()​

    classmethod initialize_from_config(hf_vision_config)

    source

    Initialize Gemma4VisionConfig from a HuggingFace vision config.

    Parameters:

    hf_vision_config (AutoConfig) – The HuggingFace vision configuration object.

    Returns:

    An initialized Gemma4VisionConfig instance.

    Return type:

    Gemma4VisionConfig

    intermediate_size​

    intermediate_size: int

    source

    Dimension of the MLP representations.

    layer_types​

    layer_types: list[str] | None = None

    source

    Per-layer attention type specification (e.g. "full_attention").

    max_position_embeddings​

    max_position_embeddings: int

    source

    The maximum sequence length supported by position embeddings.

    num_attention_heads​

    num_attention_heads: int

    source

    Number of attention heads for each attention layer.

    num_hidden_layers​

    num_hidden_layers: int

    source

    Number of hidden layers in the vision Transformer encoder.

    num_key_value_heads​

    num_key_value_heads: int

    source

    Number of key-value heads for grouped-query attention.
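In grouped-query attention, several query heads share one key-value head. A common mapping, sketched below, assigns consecutive query heads to the same key-value head; the function name is hypothetical, and the grouping order used by the vision encoder is an assumption.

```python
def kv_head_for_query_head(
    query_head: int,
    num_attention_heads: int,
    num_key_value_heads: int,
) -> int:
    if num_attention_heads % num_key_value_heads != 0:
        raise ValueError(
            "num_attention_heads must be divisible by num_key_value_heads"
        )
    # Number of query heads that share each key-value head.
    group_size = num_attention_heads // num_key_value_heads
    return query_head // group_size


# 8 query heads sharing 2 key-value heads: groups of 4.
mapping = [kv_head_for_query_head(h, 8, 2) for h in range(8)]
print(mapping)
```

When num_key_value_heads equals num_attention_heads, each query head gets its own key-value head, which corresponds to standard multi-head attention.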

    patch_size​

    patch_size: int

    source

    The size (resolution) of each patch.

    pooling_kernel_size​

    pooling_kernel_size: int

    source

    Kernel size for spatial pooling.

    position_embedding_size​

    position_embedding_size: int

    source

    Size of the position embedding table.

    rms_norm_eps​

    rms_norm_eps: float

    source

    The epsilon used by the RMS normalization layers.

    rope_theta​

    rope_theta: float = 100.0

    source

    standardize​

    standardize: bool = False

    source

    Whether to standardize the image features.

    use_bidirectional_attention​

    use_bidirectional_attention: str | None = 'vision'

    source

    Controls the scope of bidirectional attention. One of "all", "vision", or None.

    use_clipped_linears​

    use_clipped_linears: bool = False

    source

    Whether to use clipped linear layers.