Python module

max.pipelines.architectures.gemma3

Gemma 3 transformer architecture for text generation.

Gemma3Config

class max.pipelines.architectures.gemma3.Gemma3Config(*, vocab_size, hidden_size, intermediate_size, num_hidden_layers, num_attention_heads, num_key_value_heads, head_dim, hidden_activation, max_position_embeddings, rms_norm_eps, rope_theta, attention_bias, query_pre_attn_scalar, sliding_window, final_logit_softcapping, attn_logit_softcapping, rope_scaling, rope_local_base_freq, sliding_window_pattern, dtype, devices, interleaved_rope_weights, return_logits=ReturnLogits.LAST_TOKEN, kv_params, tie_word_embeddings=False, quant_config=None)

Bases: ArchConfigWithKVCache

Represents the MAX Engine configuration for Gemma 3 models.

Contains parameters specific to the Gemma 3 architecture (typically extracted from HuggingFace configs), plus MAX-specific runtime settings and helpers.

attention_bias

attention_bias: bool

Whether to use a bias in the query, key, value and output projection layers during self-attention.

attn_logit_softcapping

attn_logit_softcapping: int | None

Scaling factor when applying tanh softcapping on the attention scores.
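Tanh softcapping bounds a value to the range (-cap, cap) by computing cap * tanh(x / cap); the same operation applies to the logits when final_logit_softcapping is set. The following is a minimal sketch on plain Python floats. The helper name `softcap` is hypothetical rather than part of the MAX API, and treating None as "no softcapping" is an assumption based on the field's optional type.

```python
import math
from typing import Optional


def softcap(x: float, cap: Optional[float]) -> float:
    """Hypothetical helper: apply tanh softcapping to one score or logit.

    Assumption: cap=None means softcapping is disabled.
    """
    if cap is None:
        return x
    # cap * tanh(x / cap) is close to x for |x| much smaller than cap
    # and saturates toward +/-cap for large |x|.
    return cap * math.tanh(x / cap)


scores = [0.5, 10.0, 200.0, -1000.0]
capped = [softcap(s, 50.0) for s in scores]
print([round(c, 3) for c in capped])
```

Because tanh is smooth, softcapping limits extreme scores without the hard discontinuity of clipping.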

calculate_max_seq_len()

static calculate_max_seq_len(pipeline_config, huggingface_config)

Calculates the maximum sequence length for the model.

Uses the max_length from the max.pipelines.config.PipelineConfig if provided, otherwise falls back to the max_position_embeddings from the HuggingFace configuration’s text config.

Parameters:

  • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
  • huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).

Returns:

The calculated maximum sequence length.

Return type:

int
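The fallback rule described above can be sketched as follows. The stand-in objects built with SimpleNamespace and the helper name `resolve_max_seq_len` are hypothetical; the real method takes a PipelineConfig and a transformers.AutoConfig and may perform validation not shown here. The attribute paths mirror the ones this reference mentions (pipeline_config.model.max_length and the HuggingFace text config's max_position_embeddings).

```python
from types import SimpleNamespace
from typing import Any


def resolve_max_seq_len(pipeline_config: Any, huggingface_config: Any) -> int:
    """Hypothetical sketch of the max-sequence-length fallback rule."""
    # Prefer an explicit --max-length setting on the model config.
    max_length = pipeline_config.model.max_length
    if max_length is not None:
        return max_length
    # Otherwise fall back to the text config's positional-embedding limit.
    return huggingface_config.text_config.max_position_embeddings


hf = SimpleNamespace(text_config=SimpleNamespace(max_position_embeddings=131072))
explicit = SimpleNamespace(model=SimpleNamespace(max_length=8192))
default = SimpleNamespace(model=SimpleNamespace(max_length=None))

print(resolve_max_seq_len(explicit, hf))  # 8192
print(resolve_max_seq_len(default, hf))   # 131072
```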

construct_kv_params()

static construct_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)

Constructs the KV cache parameters from configuration objects.

Parameters:

  • huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).
  • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
  • devices (list[DeviceRef]) – The list of devices the model will run on.
  • kv_cache_config (KVCacheConfig) – The MAX Engine KV cache configuration settings (max.pipelines.max_config.KVCacheConfig).
  • cache_dtype (DType) – The desired data type for the KV cache (max.dtype.DType).

Returns:

The configured max.pipelines.kv_cache.KVCacheParams object.

Return type:

KVCacheParams
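The memory a KV cache needs is governed by attention-shape fields such as num_hidden_layers, num_key_value_heads, and head_dim, together with the cache dtype. The sketch below is a rough, hypothetical back-of-the-envelope estimate, not the formula MAX uses internally: it ignores paging, sharding across devices, and any savings from sliding-window layers, and the shape values are illustrative rather than taken from a specific checkpoint.

```python
def kv_bytes_per_token(
    num_layers: int, num_kv_heads: int, head_dim: int, dtype_bytes: int
) -> int:
    """Bytes needed to cache one token's keys and values across all layers.

    Assumes every layer caches full-length K and V (no sliding-window trimming).
    """
    # Factor of 2: one key vector and one value vector per KV head.
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes


# Illustrative shape values; bfloat16 uses 2 bytes per element.
per_token = kv_bytes_per_token(num_layers=34, num_kv_heads=4, head_dim=256, dtype_bytes=2)
max_seq_len = 8192
print(per_token)                                           # 139264
print(per_token * max_seq_len / 2**30, "GiB per sequence")  # 1.0625 GiB per sequence
```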

devices

devices: list[DeviceRef]

Devices to run the model with.

dtype

dtype: DType

DType of the model weights and input.

final_logit_softcapping

final_logit_softcapping: float | None

Scaling factor when applying tanh softcapping on the logits.

finalize()

finalize(huggingface_config, state_dict, return_logits, quant_config=None)

Defines parameters that cannot be determined from the pipeline config alone.

This method sets fields that require introspection of the model weights (state_dict), such as tie_word_embeddings and quant_config.

Parameters:

  • huggingface_config (AutoConfig) – The HuggingFace model configuration object.
  • state_dict (dict[str, WeightData]) – The model’s state dictionary containing weights.
  • return_logits (ReturnLogits) – Whether to return logits for the last token, for all tokens, or for a variable number of tokens.
  • quant_config (QuantConfig | None) – Scaled quantization configuration (optional).

Return type:

None

get_kv_params()

get_kv_params()

Returns the KV cache parameters to use when running the model.

Return type:

KVCacheParams

get_max_seq_len()

get_max_seq_len()

Returns the default maximum sequence length for the model.

Subclasses should determine whether this value can be overridden by setting the --max-length (pipeline_config.model.max_length) flag.

Return type:

int

get_num_layers()

static get_num_layers(huggingface_config)

Retrieves the number of hidden layers from the HuggingFace configuration.

Parameters:

huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).

Returns:

The number of hidden layers specified in the configuration’s text config.

Return type:

int

head_dim

head_dim: int

The attention head dimension.

hidden_activation

hidden_activation: str

The name of the non-linear activation function in the decoder. Defaults to “gelu_tanh” if not specified; “gelu_tanh” is a tanh-based approximation of the “gelu” activation function.
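The “gelu_tanh” variant replaces the exact GELU, x * Phi(x) with Phi the standard normal CDF, by a cheaper tanh-based formula. The sketch below compares the two on a few points using only the standard library; the function names are illustrative, and MAX's kernels may compute the approximation differently in detail.

```python
import math


def gelu_exact(x: float) -> float:
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x: float) -> float:
    # Tanh approximation of GELU.
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)
    return 0.5 * x * (1.0 + math.tanh(inner))


for x in (-3.0, -1.0, 0.0, 0.5, 2.0):
    print(f"{x:5.1f}  exact={gelu_exact(x):.5f}  tanh={gelu_tanh(x):.5f}")
```

The two curves agree closely across typical activation ranges, which is why the approximation is a common default.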

hidden_size

hidden_size: int

Dimension of the hidden representations.

initialize()

classmethod initialize(pipeline_config, model_config=None)

Initializes the config from a PipelineConfig.

Parameters:

  • pipeline_config (PipelineConfig) – The pipeline configuration.
  • model_config (MAXModelConfig | None) – The model configuration to read from. When None (the default), pipeline_config.model is used. Pass an explicit config (e.g. pipeline_config.draft_model) to initialize the arch config for a different model.

Return type:

Self

initialize_from_config()

classmethod initialize_from_config(pipeline_config, huggingface_config)

Initializes a Gemma3Config instance from pipeline and HuggingFace configuration.

This method creates a config instance with all fields that can be determined from the pipeline and HuggingFace configuration, without needing the state_dict. Fields that depend on the state_dict (like tie_word_embeddings, quant_config) should be set via the finalize() method.

Parameters:

  • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
  • huggingface_config (AutoConfig) – The HuggingFace model configuration object.

Returns:

An initialized Gemma3Config instance.

Return type:

Self
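The split between initialize_from_config() and finalize() is a two-phase construction pattern: build everything the configs determine, then fill in weight-dependent fields after the state_dict is available. The toy class below is a hypothetical stand-in that mirrors only that pattern. The class name, the dict-based configs, and the weight names "lm_head.weight" and "embed_tokens.weight" are assumptions for illustration, and the tying heuristic is one plausible rule, not necessarily what Gemma3Config.finalize() does.

```python
from dataclasses import dataclass
from typing import Any, Mapping, Optional


@dataclass
class ToyArchConfig:
    """Hypothetical stand-in mirroring the two-phase Gemma3Config pattern."""

    vocab_size: int
    hidden_size: int
    # State-dict-dependent fields start with placeholder defaults.
    tie_word_embeddings: bool = False
    quant_config: Optional[Any] = None

    @classmethod
    def initialize_from_config(cls, hf_config: Mapping[str, Any]) -> "ToyArchConfig":
        # Phase 1: only fields derivable from the HuggingFace config.
        return cls(vocab_size=hf_config["vocab_size"], hidden_size=hf_config["hidden_size"])

    def finalize(self, state_dict: Mapping[str, Any]) -> None:
        # Phase 2: inspect the weights. Assumed heuristic: if the checkpoint
        # has no separate "lm_head.weight", the output layer reuses the embeddings.
        self.tie_word_embeddings = "lm_head.weight" not in state_dict


cfg = ToyArchConfig.initialize_from_config({"vocab_size": 262144, "hidden_size": 2560})
cfg.finalize({"embed_tokens.weight": object()})
print(cfg.tie_word_embeddings)  # True
```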

interleaved_rope_weights

interleaved_rope_weights: bool

Whether the RoPE weights are stored in interleaved complex format.
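RoPE rotates pairs of features, and checkpoints differ in which features form a pair. In an interleaved (complex) layout, adjacent entries (x[0], x[1]), (x[2], x[3]), ... are treated as real and imaginary parts; in a half-split layout, entry i pairs with entry i + d/2. The sketch below assumes that is the distinction this flag captures, and shows that the two layouts give the same result up to a fixed permutation. All function names are hypothetical.

```python
import math
from typing import List


def rope_interleaved(x: List[float], pos: int, base: float) -> List[float]:
    # Pairs are adjacent: (x[0], x[1]), (x[2], x[3]), ... as (real, imag).
    d = len(x)
    out = list(x)
    for i in range(d // 2):
        theta = pos * base ** (-2 * i / d)
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[2 * i], x[2 * i + 1]
        out[2 * i], out[2 * i + 1] = a * c - b * s, a * s + b * c
    return out


def rope_half_split(x: List[float], pos: int, base: float) -> List[float]:
    # Pairs are split across halves: (x[i], x[i + d/2]).
    d = len(x)
    half = d // 2
    out = list(x)
    for i in range(half):
        theta = pos * base ** (-2 * i / d)
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[i], x[i + half]
        out[i], out[i + half] = a * c - b * s, a * s + b * c
    return out


def interleaved_to_half_split(x: List[float]) -> List[float]:
    # Even-indexed entries become the first half, odd-indexed the second.
    return x[0::2] + x[1::2]


vec = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
a = interleaved_to_half_split(rope_interleaved(vec, pos=5, base=10000.0))
b = rope_half_split(interleaved_to_half_split(vec), pos=5, base=10000.0)
print(max(abs(p - q) for p, q in zip(a, b)))  # 0 up to float rounding
```

This equivalence is why a boolean flag can be enough to tell the runtime which pairing convention the stored weights follow.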

intermediate_size

intermediate_size: int

Dimension of the MLP representations.

kv_params

kv_params: KVCacheParams

KV cache parameters.

max_position_embeddings

max_position_embeddings: int

The maximum sequence length that this model might ever be used with.

num_attention_heads

num_attention_heads: int

Number of attention heads for each attention layer in the Transformer decoder.

num_hidden_layers

num_hidden_layers: int

Number of hidden layers in the Transformer decoder.

num_key_value_heads

num_key_value_heads: int

Number of key/value heads used to implement grouped-query attention (GQA).
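Under grouped-query attention, each key/value head is shared by a contiguous group of num_attention_heads // num_key_value_heads query heads. The sketch below shows that mapping; the helper name is hypothetical, and the contiguous grouping follows the common convention (for example, HuggingFace's repeat-KV approach) rather than anything this reference states about MAX's kernels.

```python
def kv_head_for_query_head(q_head: int, num_attention_heads: int, num_key_value_heads: int) -> int:
    """Return the KV head that a query head reads under grouped-query attention."""
    if num_attention_heads % num_key_value_heads != 0:
        raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
    group_size = num_attention_heads // num_key_value_heads
    # Consecutive query heads form a group that shares the same K/V head.
    return q_head // group_size


# Illustrative: 8 query heads sharing 4 KV heads gives groups of 2.
print([kv_head_for_query_head(q, 8, 4) for q in range(8)])  # [0, 0, 1, 1, 2, 2, 3, 3]
```

Fewer KV heads than query heads shrinks the KV cache proportionally while keeping the full number of query heads.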

quant_config

quant_config: QuantConfig | None = None

Scaled quantization configuration.

query_pre_attn_scalar

query_pre_attn_scalar: float | None

Scaling factor used on the attention scores.
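In HuggingFace's Gemma 2 and Gemma 3 implementations, attention scores are scaled by query_pre_attn_scalar ** -0.5 rather than the usual head_dim ** -0.5. The sketch below assumes MAX follows the same convention, which this reference does not confirm; the fallback to head_dim when the value is None is also an assumption, and the helper name is hypothetical.

```python
from typing import List, Optional


def scaled_dot_product(q: List[float], k: List[float], query_pre_attn_scalar: Optional[float]) -> float:
    """Attention score for one query/key pair, assuming HF-style Gemma scaling."""
    head_dim = len(q)
    # Assumed fallback: standard 1/sqrt(head_dim) when no scalar is configured.
    scalar = query_pre_attn_scalar if query_pre_attn_scalar is not None else head_dim
    scale = scalar ** -0.5
    return scale * sum(a * b for a, b in zip(q, k))


q = [1.0] * 4
k = [2.0] * 4
print(scaled_dot_product(q, k, None))   # 4.0  (8 / sqrt(4))
print(scaled_dot_product(q, k, 16.0))   # 2.0  (8 / sqrt(16))
```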

return_logits

return_logits: ReturnLogits = 'last_token'

Whether to return logits for the last token only, for all tokens, or for a variable number of tokens.

rms_norm_eps

rms_norm_eps: float

The epsilon used by the RMS normalization layers.
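RMS normalization divides each vector by the root of its mean square, with rms_norm_eps added inside the square root so the denominator never reaches zero. The sketch below assumes the (1 + weight) gain used by HuggingFace's Gemma RMSNorm; whether MAX's Gemma 3 layers use the same parameterization is not stated in this reference, and the function name is hypothetical.

```python
import math
from typing import List


def gemma_rms_norm(x: List[float], weight: List[float], eps: float = 1e-6) -> List[float]:
    """RMS normalization with an assumed (1 + weight) gain, as in HF's Gemma RMSNorm."""
    mean_sq = sum(v * v for v in x) / len(x)
    # eps keeps the denominator nonzero, e.g. for all-zero inputs.
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * (1.0 + w) for v, w in zip(x, weight)]


# With zero weights the gain is 1, so the output has unit mean square.
out = gemma_rms_norm([3.0, 4.0], weight=[0.0, 0.0], eps=0.0)
print(out)
```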

rope_local_base_freq

rope_local_base_freq: float

The base period of the RoPE embeddings for local attention.

rope_scaling

rope_scaling: LinearScalingParams | None

Scaling configuration for the RoPE embeddings used in global attention.

rope_theta

rope_theta: float

The base period of the RoPE embeddings.
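Together, rope_local_base_freq, rope_theta, and rope_scaling determine the rotation frequencies for local and global layers. The sketch below computes standard RoPE inverse frequencies, base ** (-2i / head_dim), and applies linear scaling by dividing them by a factor, which is equivalent to dividing positions by it. It assumes rope_theta applies to the global layers, and the numeric values are illustrative; read the real ones from the checkpoint's config. The helper name is hypothetical.

```python
from typing import List, Optional


def rope_inv_freq(head_dim: int, base: float, linear_factor: Optional[float] = None) -> List[float]:
    """Per-pair RoPE inverse frequencies: base ** (-2i / head_dim)."""
    inv = [base ** (-2 * i / head_dim) for i in range(head_dim // 2)]
    if linear_factor is not None:
        # Linear scaling: dividing frequencies by the factor is equivalent
        # to dividing token positions by it.
        inv = [f / linear_factor for f in inv]
    return inv


# Illustrative values only.
local = rope_inv_freq(head_dim=256, base=10_000.0)                          # local layers
global_ = rope_inv_freq(head_dim=256, base=1_000_000.0, linear_factor=8.0)  # global layers
print(local[1], global_[1])
```

A larger base and a scaling factor both slow the rotation, which lets the global layers distinguish positions over much longer contexts.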

sliding_window

sliding_window: int

The size of the sliding window. In the Gemma 3 language model, most layers use local sliding-window attention, interleaved with global attention layers as set by sliding_window_pattern.

sliding_window_pattern

sliding_window_pattern: int

The period of the interleaving between sliding-window (local) and global attention layers.
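The two sliding-window fields can be sketched as a per-layer type assignment plus an attention-visibility rule. The layer rule below is an assumption following HuggingFace's Gemma 3 convention, where every sliding_window_pattern-th layer (counting from 1) is global; the window boundary (the query position counts toward the window) is also an assumption, since implementations differ by one position here. Both helper names are hypothetical.

```python
from typing import List


def is_sliding_layer(layer_idx: int, sliding_window_pattern: int) -> bool:
    # Assumed HF-style convention: every sliding_window_pattern-th layer
    # (1-indexed) is global; all others use the sliding window.
    return (layer_idx + 1) % sliding_window_pattern != 0


def can_attend(q_pos: int, k_pos: int, sliding_window: int, sliding: bool) -> bool:
    # Causal masking: a query never attends to later positions.
    if k_pos > q_pos:
        return False
    # Local layers only see the most recent sliding_window positions.
    return not sliding or q_pos - k_pos < sliding_window


layer_types: List[str] = [
    "local" if is_sliding_layer(i, 6) else "global" for i in range(12)
]
print(layer_types)
# ['local', 'local', 'local', 'local', 'local', 'global',
#  'local', 'local', 'local', 'local', 'local', 'global']
```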

tie_word_embeddings

tie_word_embeddings: bool = False

Whether to tie the word embeddings. When True, the output linear layer reuses the embedding layer's weight.
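With tied embeddings, the output projection multiplies the final hidden state by the transpose of the [vocab_size, hidden_size] embedding table, so each token's logit is the dot product of the hidden state with that token's embedding row. The sketch below uses plain lists and a tiny made-up embedding table; the function name is hypothetical.

```python
from typing import List

Matrix = List[List[float]]


def lm_head_logits(hidden: List[float], embedding: Matrix) -> List[float]:
    """Tied LM head: logits[v] = dot(hidden, embedding[v])."""
    # Using the [vocab_size, hidden_size] embedding table as the output
    # projection amounts to multiplying by its transpose.
    return [sum(h * e for h, e in zip(hidden, row)) for row in embedding]


embedding = [
    [1.0, 0.0],  # token 0
    [0.0, 1.0],  # token 1
    [1.0, 1.0],  # token 2
]
print(lm_head_logits([2.0, 3.0], embedding))  # [2.0, 3.0, 5.0]
```

Tying saves a vocab_size x hidden_size weight matrix, which is substantial for a large vocabulary.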

vocab_size

vocab_size: int

Vocabulary size of the Gemma3Text model.

Gemma3Inputs

class max.pipelines.architectures.gemma3.Gemma3Inputs(tokens, input_row_offsets, signal_buffers, return_n_logits, *, kv_cache_inputs=None, lora_ids=None, lora_ranks=None, hidden_states=None)

Bases: ModelInputs

A class representing inputs for the Gemma3 model.

This class encapsulates the input tensors required for the Gemma3 model execution.

input_row_offsets

input_row_offsets: Buffer

Tensor containing the offsets for each row in the ragged input sequence.
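In a ragged batch, all sequences are concatenated into one flat token buffer, and the row offsets mark where each sequence starts and ends: sequence i occupies offsets[i]:offsets[i + 1]. The sketch below builds both with plain Python lists (the real fields are Buffer objects) and derives each sequence's last-token index, the position whose logits a last-token setting would keep. The helper name is hypothetical.

```python
from itertools import accumulate
from typing import List, Tuple


def pack_ragged(sequences: List[List[int]]) -> Tuple[List[int], List[int]]:
    """Flatten a batch into one token list plus row offsets."""
    tokens = [tok for seq in sequences for tok in seq]
    # offsets[i]:offsets[i + 1] slices sequence i out of the flat buffer.
    offsets = [0, *accumulate(len(seq) for seq in sequences)]
    return tokens, offsets


tokens, offsets = pack_ragged([[5, 6, 7], [8], [9, 10]])
print(tokens)   # [5, 6, 7, 8, 9, 10]
print(offsets)  # [0, 3, 4, 6]

# Index of each sequence's last token in the flat buffer.
last_positions = [end - 1 for end in offsets[1:]]
print(last_positions)  # [2, 3, 5]
```

Packing avoids padding: no compute is spent on filler tokens, and the offsets carry all the per-sequence boundary information.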

return_n_logits

return_n_logits: Buffer

Number of logits to return.

signal_buffers

signal_buffers: list[Buffer]

Device buffers used for synchronization in communication collectives.

tokens

tokens: Buffer

Tensor containing the input token IDs.

Gemma3Model

class max.pipelines.architectures.gemma3.Gemma3Model(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN)

Bases: LogProbabilitiesMixin, AlwaysSignalBuffersMixin, PipelineModelWithKVCache[TextContext]

A Gemma 3 pipeline model for text generation.

This class integrates the Gemma 3 architecture with the MAX Engine pipeline infrastructure, handling model loading, KV cache management, and input preparation for inference.

calculate_max_seq_len()

static calculate_max_seq_len(pipeline_config, huggingface_config)

Calculates the maximum sequence length for the Gemma 3 model.

Uses the max_length from the max.pipelines.config.PipelineConfig if provided, otherwise falls back to the max_position_embeddings from the HuggingFace configuration’s text config.

Parameters:

  • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
  • huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).

Returns:

The calculated maximum sequence length.

Return type:

int

execute()

execute(model_inputs)

Executes the Gemma 3 model with the prepared inputs.

Parameters:

model_inputs (ModelInputs) – The prepared inputs for the model execution, typically including token IDs, attention masks/offsets, and KV cache inputs.

Returns:

An object containing the output logits from the model execution.

Return type:

ModelOutputs

get_kv_params()

classmethod get_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)

Gets the parameters required to configure the KV cache for Gemma 3.

Delegates to the Gemma3Config.construct_kv_params static method.

Parameters:

  • huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).
  • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
  • devices (list[DeviceRef]) – The list of devices the model will run on.
  • kv_cache_config (KVCacheConfig) – The MAX Engine KV cache configuration settings (max.pipelines.max_config.KVCacheConfig).
  • cache_dtype (DType) – The desired data type for the KV cache (max.dtype.DType).

Returns:

The configured max.pipelines.kv_cache.KVCacheParams object.

Return type:

KVCacheParams

get_num_layers()

classmethod get_num_layers(huggingface_config)

Gets the number of hidden layers from the HuggingFace configuration.

Delegates to the Gemma3Config.get_num_layers static method.

Parameters:

huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).

Returns:

The number of hidden layers.

Return type:

int

load_model()

load_model(session)

Loads the compiled Gemma 3 model into the MAX Engine session.

Parameters:

session (InferenceSession) – The MAX Engine inference session.

Returns:

The loaded MAX Engine model object.

Return type:

Model

model

model: Model

The compiled and initialized MAX Engine model ready for inference.

prepare_initial_token_inputs()

prepare_initial_token_inputs(replica_batches, kv_cache_inputs=None, return_n_logits=1)

Prepares the initial inputs for the first execution pass of the Gemma 3 model.

Parameters:

  • replica_batches (Sequence[Sequence[TextContext]]) – A sequence of sequences of TextContext objects representing the input prompts for each replica.
  • kv_cache_inputs (KVCacheInputs[Buffer, Buffer] | None) – Optional inputs required by the KV cache manager.
  • return_n_logits (int) – The number of logits to return. Defaults to 1.

Returns:

The prepared ModelInputs object for the initial execution step.

Return type:

ModelInputs

prepare_next_token_inputs()

prepare_next_token_inputs(next_tokens, prev_model_inputs)

Prepares the inputs for subsequent execution steps in a multi-step generation.

Parameters:

  • next_tokens (Buffer) – The tensor containing the token IDs generated in the previous step.
  • prev_model_inputs (ModelInputs) – The ModelInputs used in the previous execution step.

Returns:

The prepared ModelInputs object for the next execution step.

Return type:

ModelInputs
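During decoding, each sequence usually contributes exactly one newly generated token per step. Under that assumption, the ragged layout becomes trivial: the flat token buffer is just the next tokens, and the row offsets collapse to 0, 1, ..., batch_size. The sketch below illustrates only that packing with a hypothetical helper; the real method also carries state such as KV cache inputs over from prev_model_inputs.

```python
from typing import List, Tuple


def next_step_inputs(next_tokens: List[int]) -> Tuple[List[int], List[int]]:
    """Hypothetical decode-step packing: one new token per sequence."""
    # Each row holds exactly one token, so offsets are 0, 1, ..., batch_size.
    offsets = list(range(len(next_tokens) + 1))
    return list(next_tokens), offsets


tokens, offsets = next_step_inputs([42, 7, 99])
print(offsets)  # [0, 1, 2, 3]
```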