
Python module

max.pipelines.architectures.gemma4

Gemma 4 vision-language architecture for multimodal text generation.

Gemma3MultiModalModelInputs​

class max.pipelines.architectures.gemma4.Gemma3MultiModalModelInputs(tokens, input_row_offsets, signal_buffers, return_n_logits, images=None, video=None, combined_embeds=None, combined_indices=None, *, kv_cache_inputs=None, lora_ids=None, lora_ranks=None, hidden_states=None)

source

Bases: ModelInputs

Inputs for the Gemma3 multimodal model.

This class encapsulates the input tensors that the Gemma3 multimodal model requires for text and vision processing.

Parameters:

buffers​

property buffers: tuple[Buffer, ...]

source

Returns positional Buffer inputs for the language model ABI.

combined_embeds​

combined_embeds: list[Buffer] | None = None

source

combined_indices​

combined_indices: list[Buffer] | None = None

source

images​

images: ImageInputs | None = None

source

input_row_offsets​

input_row_offsets: ndarray[tuple[Any, ...], dtype[integer[Any]]] | list[Buffer]

source

return_n_logits​

return_n_logits: Buffer

source

signal_buffers​

signal_buffers: list[Buffer]

source

tokens​

tokens: ndarray[tuple[Any, ...], dtype[integer[Any]]] | Buffer

source

video​

video: VideoInputs | None = None

source

Gemma3_MultiModalModel​

class max.pipelines.architectures.gemma4.Gemma3_MultiModalModel(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN)

source

Bases: AlwaysSignalBuffersMixin, PipelineModelWithKVCache[Gemma4Context]

Gemma 3 multimodal pipeline model for text generation.

This class integrates the Gemma 3 multimodal architecture with the MAX pipeline infrastructure, handling model loading, KV cache management, and input preparation for inference.

Parameters:

estimate_activation_memory()​

classmethod estimate_activation_memory(pipeline_config, huggingface_config)

source

Estimates the activation memory required for model execution.

This accounts for temporary memory buffers used during model execution, such as intermediate activations and working buffers.

The default implementation returns 0 for backward compatibility. Models with significant activation memory requirements should override this method to provide accurate estimates.

Parameters:

  • pipeline_config (PipelineConfig) – Pipeline configuration
  • huggingface_config (AutoConfig) – Hugging Face model configuration

Returns:

Estimated activation memory in bytes

Return type:

int

execute()​

execute(model_inputs)

source

Execute the vision model (if needed), then the language model.

Parameters:

model_inputs (ModelInputs)

Return type:

ModelOutputs

language_model​

language_model: Model

source

The compiled and initialized MAX Engine model ready for inference.

load_model()​

load_model(session)

source

Loads the compiled Gemma3 MultiModal models into the MAX Engine session.

Returns:

A tuple of (vision_model, language_model).

Parameters:

session (InferenceSession)

Return type:

tuple[Model, Model]

model​

property model: Model

source

Expose language model for graph capture/replay.

Only the language model is captured since vision runs during prefill only.

model_config_cls​

model_config_cls

source

alias of Gemma4ForConditionalGenerationConfig

prepare_initial_token_inputs()​

prepare_initial_token_inputs(replica_batches, kv_cache_inputs=None, return_n_logits=1)

source

Prepare inputs for the first execution pass.

Parameters:

Return type:

ModelInputs

release()​

release(request_id)

source

Release vision encoder cache for a completed request.

Parameters:

request_id (RequestID)

Return type:

None

vision_model​

vision_model: Model

source

The compiled and initialized MAX Engine vision model ready for inference.

Gemma4ForConditionalGenerationConfig​

class max.pipelines.architectures.gemma4.Gemma4ForConditionalGenerationConfig(*, devices, dtype, unquantized_dtype=bfloat16, kv_params, image_token_index, video_token_index=262144, text_config, vision_config, tie_word_embeddings=False)

source

Bases: ArchConfigWithKVAndVisionCache

Base configuration for Gemma 4 multimodal models.

This is the top-level config that composes text and vision sub-configs. Model-specific parameters live in the respective sub-configs.

Parameters:

calculate_max_seq_len()​

static calculate_max_seq_len(pipeline_config, huggingface_config)

source

Calculates the maximum sequence length for the Gemma 4 model.

Parameters:

Return type:

int

construct_kv_params()​

static construct_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)

source

Constructs KV cache parameters from the top-level HuggingFace config.

Parameters:

  • huggingface_config (AutoConfig) – The top-level HuggingFace config (with text_config).
  • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
  • devices (list[DeviceRef]) – Target devices for the model.
  • kv_cache_config (KVCacheConfig) – KV cache configuration settings.
  • cache_dtype (DType) – Data type for the KV cache.

Returns:

Configured KV cache parameters.

Return type:

MultiKVCacheParams

devices​

devices: list[DeviceRef]

source

Devices to run the model with.

dtype​

dtype: DType

source

DType of the model weights and input.

estimate_vision_cache_entry_bytes()​

static estimate_vision_cache_entry_bytes(huggingface_config)

source

Estimate per-entry bytes for the vision encoder cache.

Worst-case tokens per image is position_embedding_size / pooling_kernel_size², stored at the text hidden size in bfloat16.

Parameters:

huggingface_config (AutoConfig)

Return type:

int
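The worst-case estimate described above is simple arithmetic. The following is a minimal sketch, not the library's implementation: the helper name `estimate_entry_bytes` is hypothetical, floor division is an assumption, and the numbers are illustrative rather than taken from a real checkpoint.

```python
BFLOAT16_BYTES = 2  # bfloat16 stores each element in 2 bytes


def estimate_entry_bytes(
    position_embedding_size: int,
    pooling_kernel_size: int,
    text_hidden_size: int,
) -> int:
    """Worst-case bytes for one cached image embedding (hypothetical helper)."""
    # Worst-case tokens per image after spatial pooling (assumes floor division).
    max_tokens = position_embedding_size // (pooling_kernel_size**2)
    # Each token is one vector of the text hidden size, stored in bfloat16.
    return max_tokens * text_hidden_size * BFLOAT16_BYTES


# Illustrative values only: 900 // 3**2 = 100 tokens, 100 * 8 * 2 = 1600 bytes.
entry_bytes = estimate_entry_bytes(900, 3, 8)
```

In practice the three inputs would come from the vision and text sections of the Hugging Face config.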

finalize()​

finalize(huggingface_config, state_dict, return_logits)

source

Finalize with state_dict-dependent fields.

Parses quantization config from the weights and finalizes the text sub-config.

Parameters:

  • huggingface_config (AutoConfig) – HuggingFace model configuration.
  • state_dict (dict[str, WeightData]) – Model weights dictionary.
  • return_logits (ReturnLogits) – Return logits configuration.

Return type:

None

get_kv_params()​

get_kv_params()

source

Returns the KV cache parameters.

Return type:

MultiKVCacheParams

get_max_seq_len()​

get_max_seq_len()

source

Returns the maximum sequence length from the embedded text config.

Return type:

int

image_token_index​

image_token_index: int

source

The image token index to encode the image prompt.

initialize()​

classmethod initialize(pipeline_config, model_config=None)

source

Initializes from pipeline configuration.

Parameters:

  • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
  • model_config (MAXModelConfig | None) – Optional model config override.

Returns:

An initialized config instance.

Return type:

Self

initialize_from_config()​

classmethod initialize_from_config(pipeline_config, huggingface_config)

source

Initializes from pipeline and HuggingFace configs.

Fields that depend on the state_dict should be set via finalize().

Parameters:

  • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
  • huggingface_config (AutoConfig) – Top-level HuggingFace model configuration.

Returns:

A config instance ready for finalization.

Return type:

Self

kv_params​

kv_params: MultiKVCacheParams

source

KV cache parameters.

text_config​

text_config: Gemma4TextConfig

source

The config object of the text backbone.

tie_word_embeddings​

tie_word_embeddings: bool = False

source

Whether to tie weight embeddings. When true, the output linear layer uses the same weight as the embedding layer.
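As a rough illustration of weight tying, the sketch below reuses one embedding table both to look up token vectors and to project a hidden state back to vocabulary logits. The function names and the tiny table are hypothetical, and the code uses plain Python lists rather than MAX tensors.

```python
def embed(token_id: int, embedding_table: list[list[float]]) -> list[float]:
    """Look up the embedding row for a token."""
    return embedding_table[token_id]


def tied_logits(hidden: list[float], embedding_table: list[list[float]]) -> list[float]:
    """Project a hidden state to vocabulary logits using the embedding rows.

    With tied weights, the output layer computes hidden @ E^T, where E is
    the same table used by embed().
    """
    return [sum(h * w for h, w in zip(hidden, row)) for row in embedding_table]


# A toy vocabulary of three tokens with two-dimensional embeddings.
table = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
logits = tied_logits([2.0, 3.0], table)
```

When tie_word_embeddings is False, the output layer would instead carry its own separate weight matrix.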

unquantized_dtype​

unquantized_dtype: DType = bfloat16

source

DType of unquantized weights.

video_token_index​

video_token_index: int = 262144

source

The video token index to encode the video prompt.

vision_config​

vision_config: Gemma4VisionConfig

source

The config object of the vision encoder.

Gemma4ReasoningParser​

class max.pipelines.architectures.gemma4.Gemma4ReasoningParser(channel_start_token_id, channel_end_token_id, tool_call_start_token_id=None, think_token_id=None)

source

Bases: ReasoningParser

Gemma 4 reasoning parser for <|channel>...<channel|> sections.

When thinking is enabled, the chat template injects a <|think|> token in the system message. The model then wraps reasoning output in <|channel>thought\n...\n<channel|> blocks. This parser identifies those blocks at the token-ID level.

Reasoning may end implicitly when a tool call begins (<|tool_call>). The tool-call token is not consumed as a delimiter; it stays in the content region for downstream tool parsing.

Parameters:

  • channel_start_token_id (int)
  • channel_end_token_id (int)
  • tool_call_start_token_id (int | None)
  • think_token_id (int | None)

from_tokenizer()​

async classmethod from_tokenizer(tokenizer)

source

Construct a reasoning parser from a tokenizer.

Parameters:

tokenizer (PipelineTokenizer[Any, Any, Any])

Return type:

Gemma4ReasoningParser

reasoning_end_token_id()​

async classmethod reasoning_end_token_id(tokenizer)

source

Returns the <channel|> token id.

Parameters:

tokenizer (PipelineTokenizer[Any, Any, Any])

Return type:

int | None

reasoning_prefix​

reasoning_prefix = 'thought\n'

source

reset()​

reset()

source

Resets per-request state.

Called at the start of each request to clear any internal state accumulated during a prior request.

Return type:

None

stream()​

stream(delta_token_ids, is_currently_reasoning=True)

source

Identifies a reasoning span within a streaming delta chunk.

Returns a ParsedReasoningDelta containing:

  • span: a ReasoningSpan with two index pairs into delta_token_ids: reasoning (content only) and reasoning_with_delimiters (includes boundary tokens).
  • is_still_reasoning: True when no end delimiter was found in this chunk and the chunk contained a reasoning section.
  • reasoning_text_formatter: callback that strips the "thought\n" prefix from decoded reasoning text.

When is_currently_reasoning=False and no <|channel> start delimiter appears in the chunk, the parser returns an empty reasoning span. Gemma 4 emits <|channel>thought\n...<channel|> even when enable_thinking is off, so callers should pass every chunk through here and let the parser dynamically detect mid-stream reasoning sections (mirroring vLLM's behavior).

Parameters:

Return type:

ParsedReasoningDelta
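To show why every chunk should pass through the parser, here is a simplified sketch of carrying the reasoning state across streamed chunks. It is a toy model, not the real stream() method: `split_chunk`, `strip_reasoning_prefix`, and the token IDs are hypothetical, and the returned flag is simply whether the chunk ended inside a reasoning block.

```python
REASONING_PREFIX = "thought\n"


def split_chunk(
    delta_token_ids: list[int],
    is_currently_reasoning: bool,
    start_id: int,
    end_id: int,
) -> tuple[list[int], list[int], bool]:
    """Split one chunk into (reasoning_tokens, content_tokens, still_reasoning)."""
    reasoning: list[int] = []
    content: list[int] = []
    in_reasoning = is_currently_reasoning
    for tok in delta_token_ids:
        if tok == start_id:
            in_reasoning = True  # a block can open mid-stream
        elif tok == end_id:
            in_reasoning = False
        elif in_reasoning:
            reasoning.append(tok)
        else:
            content.append(tok)
    return reasoning, content, in_reasoning


def strip_reasoning_prefix(text: str) -> str:
    """Drop the leading 'thought\\n' marker from decoded reasoning text."""
    return text.removeprefix(REASONING_PREFIX)


# The caller feeds the flag from one chunk into the next.
state = False
results = []
for chunk in ([5, 1, 7], [8, 2, 9]):
    reasoning, content, state = split_chunk(chunk, state, start_id=1, end_id=2)
    results.append((reasoning, content))
```

The first chunk opens a reasoning block even though the caller started with the flag off, which is the mid-stream detection case the text describes.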

will_reason_after_prompt()​

will_reason_after_prompt(prompt_token_ids)

source

Predicts whether the model will emit reasoning after this prompt.

Gemma 4 enables thinking by injecting <|think|> in the system message. When present, every assistant turn opens a <|channel>thought\n... block. Checking for <|think|> is the right signal: channel delimiters from prior turns are irrelevant because a new thinking block always starts.

Parameters:

prompt_token_ids (Sequence[int])

Return type:

bool

Gemma4TextConfig​

class max.pipelines.architectures.gemma4.Gemma4TextConfig(*, vocab_size, hidden_size, intermediate_size, num_hidden_layers, num_attention_heads, num_key_value_heads, head_dim, hidden_activation, max_position_embeddings, rms_norm_eps, rope_theta, attention_bias, query_pre_attn_scalar=None, sliding_window, final_logit_softcapping, attn_logit_softcapping=None, rope_scaling, rope_local_base_freq, sliding_window_pattern, dtype, devices, interleaved_rope_weights, return_logits=ReturnLogits.LAST_TOKEN, kv_params, tie_word_embeddings=False, quant_config=None, vocab_size_per_layer_input=262144, hidden_size_per_layer_input=0, num_global_key_value_heads=4, global_head_dim=512, attention_k_eq_v=True, num_kv_shared_layers=0, enable_moe_block=False, use_double_wide_mlp=False, num_experts=0, top_k_experts=0, moe_intermediate_size=0, global_rope_scaling=None, global_rope_theta=1000000.0, sliding_window_rope_theta=10000.0, layer_types, max_seq_len)

source

Bases: Gemma3Config

Text configuration for Gemma 4 models.

Defaults are from gemma-4-31b-it.

The following parameters are ignored:
  • routed_layer_pattern
  • stream_and_decode_in_f32

    Parameters:

    • vocab_size (int)
    • hidden_size (int)
    • intermediate_size (int)
    • num_hidden_layers (int)
    • num_attention_heads (int)
    • num_key_value_heads (int)
    • head_dim (int)
    • hidden_activation (str)
    • max_position_embeddings (int)
    • rms_norm_eps (float)
    • rope_theta (float)
    • attention_bias (bool)
    • query_pre_attn_scalar (float | None)
    • sliding_window (int)
    • final_logit_softcapping (float | None)
    • attn_logit_softcapping (int | None)
    • rope_scaling (LinearScalingParams | None)
    • rope_local_base_freq (float)
    • sliding_window_pattern (int)
    • dtype (DType)
    • devices (list[DeviceRef])
    • interleaved_rope_weights (bool)
    • return_logits (ReturnLogits)
    • kv_params (KVCacheParams)
    • tie_word_embeddings (bool)
    • quant_config (QuantConfig | None)
    • vocab_size_per_layer_input (int)
    • hidden_size_per_layer_input (int)
    • num_global_key_value_heads (int)
    • global_head_dim (int)
    • attention_k_eq_v (bool)
    • num_kv_shared_layers (int)
    • enable_moe_block (bool)
    • use_double_wide_mlp (bool)
    • num_experts (int)
    • top_k_experts (int)
    • moe_intermediate_size (int)
    • global_rope_scaling (ProportionalScalingParams | None)
    • global_rope_theta (float)
    • sliding_window_rope_theta (float)
    • layer_types (list[str])
    • max_seq_len (int)

    attention_k_eq_v​

    attention_k_eq_v: bool = True

    source

    Whether the key and value projections are the same.

    When true, the checkpoint will not contain v_proj and v_norm weights.

    attn_logit_softcapping​

    attn_logit_softcapping: int | None = None

    source

    Scaling factor when applying tanh softcapping on the attention scores.

    calculate_max_seq_len()​

    classmethod calculate_max_seq_len(pipeline_config, huggingface_config, model_config=None)

    source

    Uses max_length when set, else max_position_embeddings.

    Parameters:

    Return type:

    int

    enable_moe_block​

    enable_moe_block: bool = False

    source

    Whether the model uses mixture-of-experts (MoE) blocks.

    get_max_seq_len()​

    get_max_seq_len()

    source

    Returns the resolved maximum sequence length stored on the config.

    Return type:

    int

    global_head_dim​

    global_head_dim: int = 512

    source

    Head dimension used in full attention layers.

    global_rope_scaling​

    global_rope_scaling: ProportionalScalingParams | None = None

    source

    Scaling configuration for the RoPE embeddings used in global attention.

    global_rope_theta​

    global_rope_theta: float = 1000000.0

    source

    RoPE theta for the embeddings used in global attention.

    hidden_size_per_layer_input​

    hidden_size_per_layer_input: int = 0

    source

    Hidden size output of the per-layer input embedding. When this is 0, the per-layer input embedding is not used.

    initialize_from_config()​

    classmethod initialize_from_config(pipeline_config, huggingface_config)

    source

    Initialize Gemma4TextConfig from pipeline and HuggingFace configs.

    Parameters:

    • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
    • huggingface_config (AutoConfig) – HuggingFace text model configuration.

    Returns:

    An initialized Gemma4TextConfig instance.

    Return type:

    Self

    layer_types​

    layer_types: list[str]

    source

    max_seq_len​

    max_seq_len: int

    source

    The actual maximum sequence length, as determined by calculate_max_seq_len().

    moe_intermediate_size​

    moe_intermediate_size: int = 0

    source

    Hidden dimension of each MoE expert's feed-forward block.

    num_experts​

    num_experts: int = 0

    source

    Total number of MoE experts.

    num_global_key_value_heads​

    num_global_key_value_heads: int = 4

    source

    Number of key value heads used in full attention layers.

    num_kv_shared_layers​

    num_kv_shared_layers: int = 0

    source

    Number of layers that share the KV cache, an optimization used in smaller models.

    query_pre_attn_scalar​

    query_pre_attn_scalar: float | None = None

    source

    Scaling factor used on the attention scores.

    rope_scaling​

    property rope_scaling: ProportionalScalingParams | None

    source

    rope_theta​

    property rope_theta: float

    source

    sliding_window_rope_theta​

    sliding_window_rope_theta: float = 10000.0

    source

    RoPE theta for the embeddings used in sliding window attention.

    top_k_experts​

    top_k_experts: int = 0

    source

    Number of experts selected per token by the router.
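Top-k routing can be pictured as picking the k highest-scoring experts for each token. The sketch below shows only that selection step on plain Python lists; `select_top_k_experts` is a hypothetical name, and how the selected experts' outputs are weighted and combined is not covered here.

```python
def select_top_k_experts(router_logits: list[float], top_k: int) -> list[int]:
    """Return the indices of the top_k highest-scoring experts for one token."""
    if not 0 <= top_k <= len(router_logits):
        raise ValueError("top_k must be between 0 and the number of experts")
    # Sort expert indices by descending router score and keep the first top_k.
    ranked = sorted(
        range(len(router_logits)),
        key=lambda i: router_logits[i],
        reverse=True,
    )
    return ranked[:top_k]


# Four experts, two selected per token (illustrative scores).
chosen = select_top_k_experts([0.1, 2.0, -1.0, 1.5], top_k=2)
```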

    use_double_wide_mlp​

    use_double_wide_mlp: bool = False

    source

    Whether the model uses a double-wide MLP.

    vocab_size_per_layer_input​

    vocab_size_per_layer_input: int = 262144

    source

    Vocab size used in the per-layer input embedding (used in smaller architectures).

    Gemma4ToolParser​

    class max.pipelines.architectures.gemma4.Gemma4ToolParser

    source

    Bases: StructuralTagToolParser

    Gemma 4 tool parser using flat <|tool_call> ... <tool_call|> pairs.

    Uses the flat (no-section-wrapper) mode of StructuralTagToolParser: only CALL_BEGIN/CALL_END are set. Arguments are emitted atomically (withheld until the close marker) because Gemma4's <|"|> string delimiters make incremental JSON conversion non-monotonic.
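The atomic emission described above can be sketched as a small buffer that withholds a call body until its close marker arrives. This is an illustration only: `AtomicCallBuffer` is a hypothetical class, it does not use StructuralTagToolParser, and the call body is treated as an opaque string rather than parsed.

```python
CALL_BEGIN = "<|tool_call>"
CALL_END = "<tool_call|>"


class AtomicCallBuffer:
    """Accumulates streamed text and releases only complete tool-call bodies."""

    def __init__(self) -> None:
        self._buffer = ""

    def feed(self, text_delta: str) -> list[str]:
        """Add streamed text; return bodies of any calls that just completed."""
        self._buffer += text_delta
        completed: list[str] = []
        while True:
            begin = self._buffer.find(CALL_BEGIN)
            if begin == -1:
                break
            body_start = begin + len(CALL_BEGIN)
            end = self._buffer.find(CALL_END, body_start)
            if end == -1:
                break  # close marker not seen yet: keep withholding the body
            completed.append(self._buffer[body_start:end])
            # Simplification: text before the completed call is dropped here.
            self._buffer = self._buffer[end + len(CALL_END):]
        return completed


buf = AtomicCallBuffer()
first = buf.feed("<|tool_call>pay")       # body incomplete, nothing released
second = buf.feed("load<tool_call|>")     # close marker arrives, body released
```

Releasing the body in one piece avoids converting a half-received argument string whose meaning could change once later text arrives.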

    CALL_BEGIN​

    CALL_BEGIN: ClassVar[str] = '<|tool_call>'

    source

    CALL_END​

    CALL_END: ClassVar[str] = '<tool_call|>'

    source

    generate_tool_call_grammar()​

    static generate_tool_call_grammar(response_format_schema=None, tools=None, tokenizer=None, **kwargs)

    source

    Generates a Lark grammar for constrained decoding of Gemma4 tool calls.

    Parameters:

    Return type:

    str

    Gemma4VisionConfig​

    class max.pipelines.architectures.gemma4.Gemma4VisionConfig(hidden_size, intermediate_size, num_hidden_layers, num_attention_heads, num_key_value_heads, head_dim, hidden_activation, rms_norm_eps, max_position_embeddings, patch_size, position_embedding_size, pooling_kernel_size, standardize=False, attention_bias=False, attention_dropout=0.0, use_bidirectional_attention='vision', layer_types=None, use_clipped_linears=False, rope_theta=100.0)

    source

    Bases: object

    Vision-specific configuration for Gemma 4 models.

    Parameters:

    • hidden_size (int)
    • intermediate_size (int)
    • num_hidden_layers (int)
    • num_attention_heads (int)
    • num_key_value_heads (int)
    • head_dim (int)
    • hidden_activation (str)
    • rms_norm_eps (float)
    • max_position_embeddings (int)
    • patch_size (int)
    • position_embedding_size (int)
    • pooling_kernel_size (int)
    • standardize (bool)
    • attention_bias (bool)
    • attention_dropout (float)
    • use_bidirectional_attention (str | None)
    • layer_types (list[str] | None)
    • use_clipped_linears (bool)
    • rope_theta (float)

    attention_bias​

    attention_bias: bool = False

    source

    Whether to use bias in attention projection layers.

    attention_dropout​

    attention_dropout: float = 0.0

    source

    The dropout ratio for the attention probabilities.

    head_dim​

    head_dim: int

    source

    Dimension of each attention head.

    hidden_activation​

    hidden_activation: str

    source

    The non-linear activation function in the encoder.

    hidden_size​

    hidden_size: int

    source

    Dimensionality of the encoder layers.

    initialize_from_config()​

    classmethod initialize_from_config(hf_vision_config)

    source

    Initialize Gemma4VisionConfig from a HuggingFace vision config.

    Parameters:

    hf_vision_config (AutoConfig) – The HuggingFace vision configuration object.

    Returns:

    An initialized Gemma4VisionConfig instance.

    Return type:

    Gemma4VisionConfig

    intermediate_size​

    intermediate_size: int

    source

    Dimension of the MLP representations.

    layer_types​

    layer_types: list[str] | None = None

    source

    Per-layer attention type specification (e.g. "full_attention").

    max_position_embeddings​

    max_position_embeddings: int

    source

    The maximum sequence length supported by position embeddings.

    num_attention_heads​

    num_attention_heads: int

    source

    Number of attention heads for each attention layer.

    num_hidden_layers​

    num_hidden_layers: int

    source

    Number of hidden layers in the vision Transformer encoder.

    num_key_value_heads​

    num_key_value_heads: int

    source

    Number of key-value heads for grouped-query attention.
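In grouped-query attention, several query heads share one key-value head. A minimal sketch of the usual head mapping, assuming query heads are grouped contiguously and that the head counts divide evenly; `kv_head_for_query_head` is a hypothetical name.

```python
def kv_head_for_query_head(
    query_head: int,
    num_attention_heads: int,
    num_key_value_heads: int,
) -> int:
    """Return the index of the key-value head that a query head reads from."""
    if num_attention_heads % num_key_value_heads != 0:
        raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
    if not 0 <= query_head < num_attention_heads:
        raise ValueError("query_head out of range")
    # Each group of consecutive query heads shares a single key-value head.
    group_size = num_attention_heads // num_key_value_heads
    return query_head // group_size


# Eight query heads sharing two key-value heads: groups of four.
mapping = [kv_head_for_query_head(h, 8, 2) for h in range(8)]
```

When num_key_value_heads equals num_attention_heads, the mapping is the identity and the layer behaves like standard multi-head attention.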

    patch_size​

    patch_size: int

    source

    The size (resolution) of each patch.

    pooling_kernel_size​

    pooling_kernel_size: int

    source

    Kernel size for spatial pooling.

    position_embedding_size​

    position_embedding_size: int

    source

    Size of the position embedding table.

    rms_norm_eps​

    rms_norm_eps: float

    source

    The epsilon used by the RMS normalization layers.

    rope_theta​

    rope_theta: float = 100.0

    source

    standardize​

    standardize: bool = False

    source

    Whether to standardize the image features.

    use_bidirectional_attention​

    use_bidirectional_attention: str | None = 'vision'

    source

    Controls bidirectional attention scope. "all", "vision", or None.

    use_clipped_linears​

    use_clipped_linears: bool = False

    source

    Whether to use clipped linear layers.