Python module

max.pipelines.architectures.gemma3

Gemma 3 transformer architecture for text generation.

Gemma3Config

class max.pipelines.architectures.gemma3.Gemma3Config(*, vocab_size, hidden_size, intermediate_size, num_hidden_layers, num_attention_heads, num_key_value_heads, head_dim, hidden_activation, max_position_embeddings, rms_norm_eps, rope_theta, attention_bias, query_pre_attn_scalar, sliding_window, final_logit_softcapping, attn_logit_softcapping, rope_scaling, rope_local_base_freq, sliding_window_pattern, dtype, devices, interleaved_rope_weights, return_logits=ReturnLogits.LAST_TOKEN, kv_params, tie_word_embeddings=False, quant_config=None)

Bases: ArchConfigWithKVCache

Represents the MAX Engine configuration for Gemma 3 models.

Contains parameters specific to the Gemma 3 architecture (typically extracted from HuggingFace configs), plus MAX-specific runtime settings and helpers.

Parameters:

attention_bias

attention_bias: bool

Whether to use a bias in the query, key, value and output projection layers during self-attention.

attn_logit_softcapping

attn_logit_softcapping: int | None

Scaling factor when applying tanh softcapping on the attention scores.

calculate_max_seq_len()

static calculate_max_seq_len(pipeline_config, huggingface_config)

Calculates the maximum sequence length for the model.

Uses the max_length from the max.pipelines.config.PipelineConfig if provided, otherwise falls back to the max_position_embeddings from the HuggingFace configuration’s text config.

Parameters:

  • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
  • huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).

Returns:

The calculated maximum sequence length.

Return type:

int
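
The fallback logic described above can be sketched as follows. This is a minimal stand-in using hypothetical dataclasses (the real method reads max.pipelines.config.PipelineConfig and a transformers.AutoConfig, and the max_length field actually lives under pipeline_config.model):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeTextConfig:  # stand-in for huggingface_config.text_config
    max_position_embeddings: int

@dataclass
class FakeHFConfig:
    text_config: FakeTextConfig

@dataclass
class FakePipelineConfig:
    max_length: Optional[int]  # corresponds to the --max-length flag

def calculate_max_seq_len(pipeline_config, huggingface_config) -> int:
    # Prefer the user-supplied max_length; otherwise fall back to
    # max_position_embeddings from the HuggingFace text config.
    if pipeline_config.max_length is not None:
        return pipeline_config.max_length
    return huggingface_config.text_config.max_position_embeddings

hf = FakeHFConfig(FakeTextConfig(max_position_embeddings=131072))
# An explicit max_length wins over the HF value.
assert calculate_max_seq_len(FakePipelineConfig(max_length=4096), hf) == 4096
# Without one, the HF text config value is used.
assert calculate_max_seq_len(FakePipelineConfig(max_length=None), hf) == 131072
```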

construct_kv_params()

static construct_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)

Constructs the KV cache parameters from configuration objects.

Parameters:

  • huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).
  • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
  • devices (list[DeviceRef]) – The list of devices the model will run on.
  • kv_cache_config (KVCacheConfig) – The MAX Engine KV cache configuration settings (max.pipelines.max_config.KVCacheConfig).
  • cache_dtype (DType) – The desired data type for the KV cache (max.dtype.DType).

Returns:

The configured max.pipelines.kv_cache.KVCacheParams object.

Return type:

KVCacheParams

devices

devices: list[DeviceRef]

Devices to run the model with.

dtype

dtype: DType

DType of the model weights and input.

final_logit_softcapping

final_logit_softcapping: float | None

Scaling factor when applying tanh softcapping on the logits.
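
Tanh softcapping (used for both final_logit_softcapping and attn_logit_softcapping) smoothly bounds a value to the interval (-cap, cap) via the standard formula cap * tanh(x / cap). A minimal sketch, illustrative of the formula rather than MAX's implementation:

```python
import math

def softcap(x: float, cap: float) -> float:
    # Smoothly bound x to roughly (-cap, cap): near-identity for small x,
    # saturating toward +/-cap for large x.
    return cap * math.tanh(x / cap)

# Small values pass through nearly unchanged; large ones saturate at the cap.
assert abs(softcap(1.0, 30.0) - 1.0) < 0.01
assert 29.9 < softcap(1000.0, 30.0) <= 30.0
```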

finalize()

finalize(huggingface_config, state_dict, return_logits, quant_config=None)

Defines parameters that can’t be determined from the pipeline config alone.

This method sets fields that require introspection of the model weights (state_dict), such as tie_word_embeddings and quant_config.

Parameters:

  • huggingface_config (AutoConfig) – The HuggingFace model configuration object.
  • state_dict (dict[str, WeightData]) – The model’s state dictionary containing weights.
  • return_logits (ReturnLogits) – Whether to return the last token, all tokens, or a variable number of logits.
  • quant_config (QuantConfig | None) – Scaled quantization configuration (optional).

Return type:

None

get_kv_params()

get_kv_params()

KV cache parameters to use when running the model.

Return type:

KVCacheParams

get_max_seq_len()

get_max_seq_len()

Returns the default maximum sequence length for the model.

Subclasses should determine whether this value can be overridden by setting the --max-length (pipeline_config.model.max_length) flag.

Return type:

int

get_num_layers()

static get_num_layers(huggingface_config)

Retrieves the number of hidden layers from the HuggingFace configuration.

Parameters:

huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).

Returns:

The number of hidden layers specified in the configuration’s text config.

Return type:

int

head_dim

head_dim: int

The attention head dimension.

hidden_activation

hidden_activation: str

The non-linear activation function (function or string) used in the decoder. Defaults to “gelu_tanh” if not specified; “gelu_tanh” is a tanh approximation of the “gelu” activation function.

hidden_size

hidden_size: int

Dimension of the hidden representations.

initialize()

classmethod initialize(pipeline_config, model_config=None)

Initialize the config from a PipelineConfig.

Parameters:

  • pipeline_config (PipelineConfig) – The pipeline configuration.
  • model_config (MAXModelConfig | None) – The model configuration to read from. When None (the default), pipeline_config.model is used. Pass an explicit config (e.g. pipeline_config.draft_model) to initialize the arch config for a different model.

Return type:

Self

initialize_from_config()

classmethod initialize_from_config(pipeline_config, huggingface_config)

Initializes a Gemma3Config instance from pipeline and HuggingFace configuration.

This method creates a config instance with all fields that can be determined from the pipeline and HuggingFace configuration, without needing the state_dict. Fields that depend on the state_dict (like tie_word_embeddings, quant_config) should be set via the finalize() method.

Parameters:

  • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
  • huggingface_config (AutoConfig) – The HuggingFace model configuration object.

Returns:

An initialized Gemma3Config instance.

Return type:

Self
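
The split between initialize_from_config() and finalize() can be illustrated with a toy stand-in. This is not the real Gemma3Config, and the "lm_head.weight" key is a hypothetical example of the kind of weight introspection finalize() performs, not necessarily the exact key MAX checks:

```python
class TwoPhaseConfig:
    """Toy illustration of a two-phase config: config-file fields first,
    then state_dict-dependent fields in finalize()."""

    @classmethod
    def initialize_from_config(cls, hf_config: dict) -> "TwoPhaseConfig":
        # Phase 1: everything derivable from configuration alone.
        cfg = cls.__new__(cls)
        cfg.vocab_size = hf_config["vocab_size"]
        cfg.tie_word_embeddings = False  # placeholder until finalize()
        return cfg

    def finalize(self, state_dict: dict) -> None:
        # Phase 2: fields that need weight introspection. If no separate
        # output-projection weight exists, embeddings must be tied.
        self.tie_word_embeddings = "lm_head.weight" not in state_dict

cfg = TwoPhaseConfig.initialize_from_config({"vocab_size": 262144})
cfg.finalize({"model.embed_tokens.weight": object()})
assert cfg.tie_word_embeddings is True
```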

interleaved_rope_weights

interleaved_rope_weights: bool

True if the RoPE weights are in interleaved complex format.

intermediate_size

intermediate_size: int

Dimension of the MLP representations.

kv_params

kv_params: KVCacheParams

KV cache parameters.

max_position_embeddings

max_position_embeddings: int

The maximum sequence length that this model might ever be used with.

num_attention_heads

num_attention_heads: int

Number of attention heads for each attention layer in the Transformer decoder.

num_hidden_layers

num_hidden_layers: int

Number of hidden layers in the Transformer decoder.

num_key_value_heads

num_key_value_heads: int

Number of key-value heads used to implement Grouped Query Attention.

quant_config

quant_config: QuantConfig | None = None

Scaled quantization configuration.

query_pre_attn_scalar

query_pre_attn_scalar: float | None

Scaling factor used on the attention scores.

return_logits

return_logits: ReturnLogits = 'last_token'

Whether to return the last token, all logits, or a variable number of logits.

rms_norm_eps

rms_norm_eps: float

The epsilon used by the rms normalization layers.

rope_local_base_freq

rope_local_base_freq: float

The base period of the RoPE embeddings for local attention.

rope_scaling

rope_scaling: LinearScalingParams | None

Scaling configuration for the RoPE embeddings used in global attention.

rope_theta

rope_theta: float

The base period of the RoPE embeddings.
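
rope_theta and rope_local_base_freq both play the role of the base b in the standard RoPE frequency schedule, where dimension pair i rotates with inverse frequency b**(-2i/d). A minimal sketch of that standard schedule (the exact computation inside MAX may differ):

```python
def rope_inv_freqs(base: float, head_dim: int) -> list[float]:
    # Inverse frequency for each rotary dimension pair: base**(-2i/head_dim).
    # A larger base stretches rotations, supporting longer contexts.
    return [base ** (-2.0 * i / head_dim) for i in range(head_dim // 2)]

freqs = rope_inv_freqs(base=10000.0, head_dim=8)
# The first pair always rotates with inverse frequency 1.0; later pairs slower.
assert freqs[0] == 1.0
assert freqs[-1] < freqs[0]
assert len(freqs) == 4
```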

sliding_window

sliding_window: int

In the Gemma 3 language model, every other layer uses sliding window attention. This is the size of the sliding window.
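
A sliding window attention mask restricts each query position to the most recent sliding_window key positions; a minimal causal-mask sketch, illustrative rather than MAX's actual mask construction:

```python
def sliding_window_mask(seq_len: int, window: int) -> list[list[bool]]:
    # mask[q][k] is True when query position q may attend to key position k:
    # causal (k <= q) and within the last `window` positions.
    return [[q - window < k <= q for k in range(seq_len)]
            for q in range(seq_len)]

mask = sliding_window_mask(seq_len=5, window=2)
assert mask[4] == [False, False, False, True, True]  # only the last 2 keys
assert mask[0] == [True, False, False, False, False]  # no future positions
```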

sliding_window_pattern

sliding_window_pattern: int

The pattern of attention layer types: every sliding_window_pattern-th layer uses global attention, while the remaining layers use sliding window attention.

tie_word_embeddings

tie_word_embeddings: bool = False

Whether to tie weight embeddings. When true, the output linear layer uses the same weight as the embedding layer.
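
Weight tying reuses the token-embedding matrix, transposed, as the output projection. A minimal sketch with plain lists (illustrative only; the real model does this as a batched matrix multiply):

```python
def tied_logits(hidden: list[float],
                embedding: list[list[float]]) -> list[float]:
    # With tied weights, logits = hidden @ embedding.T: one dot product of
    # the hidden state with each token's embedding row, so no separate
    # output-projection weight needs to be stored.
    return [sum(h * e for h, e in zip(hidden, row)) for row in embedding]

# 3-token vocabulary with 2-dimensional embeddings.
embedding = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
logits = tied_logits([2.0, 3.0], embedding)
assert logits == [2.0, 3.0, 5.0]
```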

vocab_size

vocab_size: int

Vocabulary size of the Gemma3Text model.

Gemma3Model

class max.pipelines.architectures.gemma3.Gemma3Model(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN)

Bases: LogProbabilitiesMixin, AlwaysSignalBuffersMixin, PipelineModelWithKVCache[TextContext]

A Gemma 3 pipeline model for text generation.

This class integrates the Gemma 3 architecture with the MAX Engine pipeline infrastructure, handling model loading, KV cache management, and input preparation for inference.

calculate_max_seq_len()

static calculate_max_seq_len(pipeline_config, huggingface_config)

Calculates the maximum sequence length for the Gemma 3 model.

Uses the max_length from the max.pipelines.config.PipelineConfig if provided, otherwise falls back to the max_position_embeddings from the HuggingFace configuration’s text config.

Parameters:

  • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
  • huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).

Returns:

The calculated maximum sequence length.

Return type:

int

execute()

execute(model_inputs)

Executes the Gemma 3 model with the prepared inputs.

Parameters:

model_inputs (ModelInputs) – The prepared inputs for the model execution, typically including token IDs, attention masks/offsets, and KV cache inputs.

Returns:

An object containing the output logits from the model execution.

Return type:

ModelOutputs

get_kv_params()

classmethod get_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)

Gets the parameters required to configure the KV cache for Gemma 3.

Delegates to the Gemma3Config.construct_kv_params static method.

Parameters:

  • huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).
  • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
  • devices (list[DeviceRef]) – The list of devices the model will run on.
  • kv_cache_config (KVCacheConfig) – The MAX Engine KV cache configuration settings (max.pipelines.max_config.KVCacheConfig).
  • cache_dtype (DType) – The desired data type for the KV cache (max.dtype.DType).

Returns:

The configured max.pipelines.kv_cache.KVCacheParams object.

Return type:

KVCacheParams

get_num_layers()

classmethod get_num_layers(huggingface_config)

Gets the number of hidden layers from the HuggingFace configuration.

Delegates to the Gemma3Config.get_num_layers static method.

Parameters:

huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).

Returns:

The number of hidden layers.

Return type:

int

load_model()

load_model(session)

Loads the compiled Gemma 3 model into the MAX Engine session.

Parameters:

session (InferenceSession) – The MAX Engine inference session.

Returns:

The loaded MAX Engine model object.

Return type:

Model

model

model: Model

The compiled and initialized MAX Engine model ready for inference.

prepare_initial_token_inputs()

prepare_initial_token_inputs(replica_batches, kv_cache_inputs=None, return_n_logits=1)

Prepares the initial inputs for the first execution pass of the Gemma 3 model.

Parameters:

  • replica_batches (Sequence[Sequence[TextContext]]) – A sequence of sequences of TextContext objects representing the input prompts for each replica.
  • kv_cache_inputs (KVCacheInputs[Buffer, Buffer] | None) – Optional inputs required by the KV cache manager.
  • return_n_logits (int)

Returns:

The prepared ModelInputs object for the initial execution step.

Return type:

ModelInputs

prepare_next_token_inputs()

prepare_next_token_inputs(next_tokens, prev_model_inputs)

Prepares the inputs for subsequent execution steps in a multi-step generation.

Parameters:

  • next_tokens (Buffer) – The tensor containing the token IDs generated in the previous step.
  • prev_model_inputs (ModelInputs) – The ModelInputs used in the previous execution step.

Returns:

The prepared ModelInputs object for the next execution step.

Return type:

ModelInputs
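
Together, prepare_initial_token_inputs(), execute(), and prepare_next_token_inputs() form the multi-step decode loop. A stubbed sketch of that control flow (the stub just echoes a counter, so this shows only the call sequence, not real inference or the real argument types):

```python
class StubModel:
    # Mimics the Gemma3Model decode-loop call sequence with trivial state.
    def prepare_initial_token_inputs(self, prompt):
        return {"tokens": list(prompt)}

    def execute(self, model_inputs):
        # Pretend the next token is one past the last token seen.
        return model_inputs["tokens"][-1] + 1

    def prepare_next_token_inputs(self, next_token, prev_model_inputs):
        # Fold the newly generated token back into the inputs.
        return {"tokens": prev_model_inputs["tokens"] + [next_token]}

model = StubModel()
inputs = model.prepare_initial_token_inputs([1, 2, 3])
generated = []
for _ in range(3):  # three decode steps
    next_token = model.execute(inputs)
    generated.append(next_token)
    inputs = model.prepare_next_token_inputs(next_token, inputs)
assert generated == [4, 5, 6]
```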