Python module

max.pipelines.architectures.gemma3_modulev3

Gemma 3 transformer architecture for text generation.

Gemma3Config​

class max.pipelines.architectures.gemma3_modulev3.Gemma3Config(*, vocab_size, hidden_size, intermediate_size, num_hidden_layers, num_attention_heads, num_key_value_heads, head_dim, hidden_activation, max_position_embeddings, rms_norm_eps, rope_theta, attention_bias, query_pre_attn_scalar, sliding_window, final_logit_softcapping, attn_logit_softcapping, rope_scaling, rope_local_base_freq, sliding_window_pattern, dtype, mesh, interleaved_rope_weights, return_logits=ReturnLogits.LAST_TOKEN, kv_params, tie_word_embeddings=False, quant_config=None)

Bases: ArchConfigWithKVCache

Represents the MAX Engine configuration for Gemma 3 models.

Contains parameters specific to the Gemma 3 architecture (typically extracted from HuggingFace configs), plus MAX-specific runtime settings and helpers.

attention_bias​

attention_bias: bool

Whether to use a bias in the query, key, value and output projection layers during self-attention.

attn_logit_softcapping​

attn_logit_softcapping: int | None

Scaling factor when applying tanh softcapping on the attention scores.

calculate_max_seq_len()​

static calculate_max_seq_len(pipeline_config, huggingface_config)

Calculates the maximum sequence length for the model.

Uses the max_length from the max.pipelines.config.PipelineConfig if provided, otherwise falls back to the max_position_embeddings from the HuggingFace configuration’s text config.

Parameters:

  • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
  • huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).

Returns:

The calculated maximum sequence length.

Return type:

int

construct_kv_params()​

static construct_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)

Constructs the KV cache parameters from configuration objects.

Parameters:

  • huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).
  • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
  • devices (list[DeviceRef]) – The devices the model will run on.
  • kv_cache_config (KVCacheConfig) – The MAX Engine KV cache configuration settings (max.pipelines.max_config.KVCacheConfig).
  • cache_dtype (DType) – The desired data type for the KV cache (max.dtype.DType).

Returns:

The configured max.pipelines.kv_cache.KVCacheParams object.

Return type:

KVCacheParams

dtype​

dtype: DType

DType of the model weights and input.

final_logit_softcapping​

final_logit_softcapping: float | None

Scaling factor when applying tanh softcapping on the logits.
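As an illustration of tanh softcapping, the sketch below applies the commonly used form cap * tanh(x / cap) to a list of logits. The function name softcap is hypothetical and the code is plain Python, not the MAX implementation; it only shows how the scaling factor bounds the values.

```python
import math
from typing import List, Optional


def softcap(values: List[float], cap: Optional[float]) -> List[float]:
    """Apply tanh softcapping: cap * tanh(v / cap). No-op when cap is None."""
    if cap is None:
        return list(values)
    return [cap * math.tanh(v / cap) for v in values]


logits = [-100.0, 0.0, 5.0, 100.0]
capped = softcap(logits, cap=30.0)
# Every capped value lies strictly inside (-30, 30); small inputs change little.
print(capped)
```

Large magnitudes are squashed toward the cap, while values near zero pass through almost unchanged.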

finalize()​

finalize(huggingface_config, state_dict, return_logits, quant_config=None)

Define parameters that can’t be determined just from the pipeline config.

This method sets fields that require introspection of the model weights (state_dict), such as tie_word_embeddings and quant_config.

Parameters:

  • huggingface_config (AutoConfig) – The HuggingFace model configuration object.
  • state_dict (dict[str, WeightData]) – The model’s state dictionary containing weights.
  • return_logits (ReturnLogits) – Whether to return logits for the last token, all tokens, or a variable number of tokens.
  • quant_config (QuantConfig | None) – Scaled quantization configuration (optional).

Return type:

None

get_kv_params()​

get_kv_params()

KV cache parameters to use when running the model.

Return type:

KVCacheParams

get_max_seq_len()​

get_max_seq_len()

Returns the default maximum sequence length for the model.

Subclasses should determine whether this value can be overridden by setting the --max-length (pipeline_config.model.max_length) flag.

Return type:

int

get_num_layers()​

static get_num_layers(huggingface_config)

Retrieves the number of hidden layers from the HuggingFace configuration.

Parameters:

huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).

Returns:

The number of hidden layers specified in the configuration’s text config.

Return type:

int

head_dim​

head_dim: int

The attention head dimension.

hidden_activation​

hidden_activation: str

The non-linear activation function (function or string) in the decoder. Defaults to “gelu_tanh” if not specified. “gelu_tanh” uses an approximation of the “gelu” activation function.

hidden_size​

hidden_size: int

Dimension of the hidden representations.

initialize()​

classmethod initialize(pipeline_config, model_config=None)

Initialize the config from a PipelineConfig.

Parameters:

  • pipeline_config (PipelineConfig) – The pipeline configuration.
  • model_config (MAXModelConfig | None) – The model configuration to read from. When None (the default), pipeline_config.model is used. Pass an explicit config (e.g. pipeline_config.draft_model) to initialize the arch config for a different model.

Return type:

Self

initialize_from_config()​

classmethod initialize_from_config(pipeline_config, huggingface_config)

Initializes a Gemma3Config instance from pipeline and HuggingFace configuration.

This method creates a config instance with all fields that can be determined from the pipeline and HuggingFace configuration, without needing the state_dict. Fields that depend on the state_dict (like tie_word_embeddings, quant_config) should be set via the finalize() method.

Parameters:

  • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
  • huggingface_config (AutoConfig) – The HuggingFace model configuration object.

Returns:

An initialized Gemma3Config instance.

Return type:

Self

interleaved_rope_weights​

interleaved_rope_weights: bool

True if the rope weights are in interleaved complex format.

intermediate_size​

intermediate_size: int

Dimension of the MLP representations.

kv_params​

kv_params: KVCacheParams

KV cache parameters.

max_position_embeddings​

max_position_embeddings: int

The maximum sequence length that this model might ever be used with.

mesh​

mesh: DeviceMesh

Device mesh to run the model with.

num_attention_heads​

num_attention_heads: int

Number of attention heads for each attention layer in the Transformer decoder.

num_hidden_layers​

num_hidden_layers: int

Number of hidden layers in the Transformer decoder.

num_key_value_heads​

num_key_value_heads: int

Number of key_value heads that should be used to implement Grouped Query Attention.
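To make the relationship between num_attention_heads and num_key_value_heads concrete, here is a minimal sketch of the usual Grouped Query Attention mapping, where consecutive query heads share one key/value head. The helper name kv_head_for_query_head is hypothetical, and the head counts are example values, not taken from a real checkpoint.

```python
from typing import List


def kv_head_for_query_head(
    query_head: int, num_attention_heads: int, num_key_value_heads: int
) -> int:
    """Return the KV head index shared by the given query head."""
    if num_attention_heads % num_key_value_heads != 0:
        raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
    group_size = num_attention_heads // num_key_value_heads
    return query_head // group_size


# Example values: 8 query heads sharing 4 KV heads (groups of 2).
mapping: List[int] = [kv_head_for_query_head(h, 8, 4) for h in range(8)]
print(mapping)
```

When num_key_value_heads equals num_attention_heads the mapping is the identity, which corresponds to standard multi-head attention.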

quant_config​

quant_config: QuantConfig | None = None

Scaled quantization configuration.

query_pre_attn_scalar​

query_pre_attn_scalar: float | None

Scaling factor used on the attention scores.
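A rough sketch of how such a scalar is typically turned into an attention scale: the Hugging Face Gemma reference code multiplies scores by query_pre_attn_scalar ** -0.5. Whether MAX does exactly the same is an assumption here, and the fallback to head_dim when the value is None is also an assumption made for illustration. The function name attention_scale is hypothetical.

```python
from typing import Optional


def attention_scale(head_dim: int, query_pre_attn_scalar: Optional[float]) -> float:
    """Return the multiplier applied to the Q.K^T attention scores."""
    # Assumption: fall back to the conventional 1/sqrt(head_dim) when unset.
    base = head_dim if query_pre_attn_scalar is None else query_pre_attn_scalar
    return base ** -0.5


print(attention_scale(head_dim=128, query_pre_attn_scalar=256.0))
print(attention_scale(head_dim=256, query_pre_attn_scalar=None))
```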

return_logits​

return_logits: ReturnLogits = 'last_token'

Whether to return logits for the last token, all tokens, or a variable number of tokens.
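For the default last-token mode, the positions of interest in a ragged (concatenated) batch can be derived from row offsets. The sketch below is plain Python with a hypothetical helper name, last_token_indices, and assumes offsets of length batch_size + 1 that start at 0; it does not use the MAX API.

```python
from typing import List


def last_token_indices(row_offsets: List[int]) -> List[int]:
    """Index of the final token of each row in the flattened token buffer."""
    return [end - 1 for end in row_offsets[1:]]


# Three sequences of lengths 3, 1, and 2 packed back to back.
indices = last_token_indices([0, 3, 4, 6])
print(indices)
```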

rms_norm_eps​

rms_norm_eps: float

The epsilon used by the rms normalization layers.
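The role of the epsilon is easiest to see in a small RMS normalization sketch. The (1 + weight) scaling follows the Hugging Face Gemma reference code; treating it as what MAX does is an assumption, and rms_norm here is a hypothetical plain-Python helper.

```python
import math
from typing import List


def rms_norm(x: List[float], weight: List[float], eps: float) -> List[float]:
    """Normalize x by its root mean square, then scale by (1 + weight)."""
    mean_sq = sum(v * v for v in x) / len(x)
    # eps keeps the division stable when x is all zeros or very small.
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * (1.0 + w) for v, w in zip(x, weight)]


normed = rms_norm([3.0, 4.0], weight=[0.0, 0.0], eps=1e-6)
print(normed)
```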

rope_local_base_freq​

rope_local_base_freq: float

The base period of the RoPE embeddings for local attention.

rope_scaling​

rope_scaling: LinearScalingParams | None

Scaling configuration for the RoPE embeddings used in global attention.

rope_theta​

rope_theta: float

The base period of the RoPE embeddings.
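The base period feeds the standard RoPE inverse-frequency formula, and linear scaling divides those frequencies by a constant factor. The sketch below assumes that standard formulation; rope_inv_freqs and the linear_factor argument are hypothetical names, and the numbers are example values only. The same function could be called with rope_local_base_freq as the base for local attention layers.

```python
from typing import List


def rope_inv_freqs(head_dim: int, theta: float, linear_factor: float = 1.0) -> List[float]:
    """Inverse frequencies 1 / (factor * theta^(2i / head_dim)) for each pair i."""
    return [
        1.0 / (linear_factor * theta ** (2 * i / head_dim))
        for i in range(head_dim // 2)
    ]


unscaled = rope_inv_freqs(head_dim=4, theta=10000.0)
scaled = rope_inv_freqs(head_dim=4, theta=10000.0, linear_factor=8.0)
print(unscaled, scaled)
```

A larger base or a larger linear factor both lower the rotation frequencies, which stretches the usable position range.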

sliding_window​

sliding_window: int

In the Gemma 3 language model, a subset of layers uses sliding window (local) attention, interleaved with global attention layers according to sliding_window_pattern. This is the size of the sliding window.
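To illustrate what the window size controls, the sketch below lists which earlier positions a query token may attend to under causal sliding window attention. The exact off-by-one convention (whether the window includes the query position itself) varies between implementations, so treat the boundary here as an assumption; visible_positions is a hypothetical helper.

```python
from typing import List


def visible_positions(query_pos: int, sliding_window: int) -> List[int]:
    """Positions a causal query may attend to, counting itself in the window."""
    start = max(0, query_pos - sliding_window + 1)
    return list(range(start, query_pos + 1))


print(visible_positions(query_pos=5, sliding_window=3))
print(visible_positions(query_pos=1, sliding_window=3))
```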

sliding_window_pattern​

sliding_window_pattern: int

Pattern for the sliding window attention.
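One common reading of this value, used in the Hugging Face reference implementation, is that every sliding_window_pattern-th layer is a global attention layer and the rest are local. Whether MAX applies the identical rule is an assumption; the helper name below is hypothetical and the value 6 is an example.

```python
from typing import List


def layer_uses_sliding_window(layer_idx: int, sliding_window_pattern: int) -> bool:
    """True for local (sliding window) layers, False for global layers."""
    return (layer_idx + 1) % sliding_window_pattern != 0


kinds: List[str] = [
    "local" if layer_uses_sliding_window(i, 6) else "global" for i in range(12)
]
print(kinds)
```

With a pattern of 6, this yields five local layers followed by one global layer, repeated.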

tie_word_embeddings​

tie_word_embeddings: bool = False

Whether to tie weight embeddings. When true, the output linear layer uses the same weight as the embedding layer.
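The effect of tying can be sketched with plain lists: when no separate output weight exists, the embedding matrix itself is used to project hidden states to vocabulary logits. The function output_logits and the tiny matrices are hypothetical illustrations, not MAX code.

```python
from typing import List, Optional

Matrix = List[List[float]]


def output_logits(
    hidden: List[float], embedding: Matrix, lm_head: Optional[Matrix] = None
) -> List[float]:
    """Project a hidden state to logits, reusing the embedding when tied."""
    weight = embedding if lm_head is None else lm_head
    return [sum(h * w for h, w in zip(hidden, row)) for row in weight]


# Vocabulary of 3 tokens, hidden size 2.
embedding = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
tied = output_logits([2.0, 3.0], embedding)
print(tied)
```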

vocab_size​

vocab_size: int

Vocabulary size of the Gemma3Text model.

Gemma3Inputs​

class max.pipelines.architectures.gemma3_modulev3.Gemma3Inputs(tokens, input_row_offsets, return_n_logits, *, kv_cache_inputs=None, lora_ids=None, lora_ranks=None, hidden_states=None)

Bases: ModelInputs

A class representing inputs for the Gemma3 model (ModuleV3).

This class encapsulates the input tensors required for the Gemma3 model execution.

input_row_offsets​

input_row_offsets: Buffer

Tensor containing the offsets for each row in the ragged input sequence.
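As a sketch of the ragged layout, the helper below packs variable-length sequences into one flat token list and records where each row starts and ends. It assumes the usual convention of batch_size + 1 offsets beginning at 0; pack_ragged is a hypothetical name and the code uses Python lists rather than MAX buffers.

```python
from typing import List, Tuple


def pack_ragged(sequences: List[List[int]]) -> Tuple[List[int], List[int]]:
    """Concatenate sequences and return (tokens, row_offsets)."""
    tokens: List[int] = []
    offsets: List[int] = [0]
    for seq in sequences:
        tokens.extend(seq)
        offsets.append(len(tokens))
    return tokens, offsets


tokens, row_offsets = pack_ragged([[1, 2, 3], [4], [5, 6]])
# Row i is tokens[row_offsets[i]:row_offsets[i + 1]].
print(tokens, row_offsets)
```

This layout avoids padding: no positions are spent on filler tokens for the shorter sequences.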

return_n_logits​

return_n_logits: Buffer

Number of logits to return.

tokens​

tokens: Buffer

Tensor containing the input token IDs.

Gemma3Model​

class max.pipelines.architectures.gemma3_modulev3.Gemma3Model(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN)

Bases: LogProbabilitiesMixin, PipelineModelWithKVCache[TextContext]

A Gemma3 pipeline model for text generation using the ModuleV3 API.

This class integrates the Gemma3 architecture with the MAX Engine pipeline infrastructure using the V3 eager compilation API.

calculate_max_seq_len()​

static calculate_max_seq_len(pipeline_config, huggingface_config)

Calculates the optimal max sequence length for the model.

Models are expected to implement this method. The following example shows how to implement it for a Mistral model:

class MistralModel(PipelineModel):
    @classmethod
    def calculate_max_seq_len(cls, pipeline_config, huggingface_config) -> int:
        try:
            return upper_bounded_default(
                upper_bound=huggingface_config.max_seq_len,
                default=pipeline_config.model.max_length,
            )
        except ValueError as e:
            raise ValueError(
                "Unable to infer max_length for Mistral, the provided "
                f"max_length ({pipeline_config.model.max_length}) exceeds the "
                f"model's max_seq_len ({huggingface_config.max_seq_len})."
            ) from e

Parameters:

  • pipeline_config (PipelineConfig) – Configuration for the pipeline.
  • huggingface_config (AutoConfig) – Hugging Face model configuration.

Returns:

The maximum sequence length to use.

Return type:

int
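The Mistral example above relies on a bounded-default helper. The stand-in below shows the behavior that example implies: use the user-supplied value when it fits, fall back to the model limit when none is given, and raise ValueError when the request exceeds the limit. The name bounded_default_sketch is hypothetical and this is not the MAX implementation.

```python
from typing import Optional


def bounded_default_sketch(upper_bound: int, default: Optional[int]) -> int:
    """Hypothetical stand-in: pick default if set and within upper_bound."""
    if default is None:
        return upper_bound
    if default > upper_bound:
        raise ValueError(
            f"default ({default}) exceeds upper bound ({upper_bound})"
        )
    return default


print(bounded_default_sketch(upper_bound=8192, default=None))
print(bounded_default_sketch(upper_bound=8192, default=4096))
```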

execute()​

execute(model_inputs)

Executes the Gemma3 model with the prepared inputs.

Parameters:

model_inputs (ModelInputs)

Return type:

ModelOutputs

get_kv_params()​

classmethod get_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)

Returns the KV cache params for the pipeline model.

Return type:

KVCacheParams

get_num_layers()​

classmethod get_num_layers(huggingface_config)

Parameters:

huggingface_config (AutoConfig)

Return type:

int

load_model()​

load_model()

Loads the compiled Gemma3 model using the ModuleV3 API.

Return type:

Callable[[…], Any]

prepare_initial_token_inputs()​

prepare_initial_token_inputs(replica_batches, kv_cache_inputs=None, return_n_logits=1)

Prepares the initial inputs to be passed to execute().

The inputs and functionality can vary per model. For example, model inputs could include encoded tensors, unique IDs per tensor when using a KV cache manager, and kv_cache_inputs (or None if the model does not use KV cache). This method typically batches encoded tensors, claims a KV cache slot if needed, and returns the inputs and caches.

Return type:

ModelInputs

prepare_next_token_inputs()​

prepare_next_token_inputs(next_tokens, prev_model_inputs)

Prepares the secondary inputs to be passed to execute().

While prepare_initial_token_inputs manages the initial inputs, this function updates the inputs for each subsequent step in a multi-step execution pattern.

Return type:

ModelInputs
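As a rough picture of what changes between steps, the sketch below assumes each sequence contributes exactly one new token per decode step, so the ragged row offsets collapse to 0, 1, ..., batch_size. The helper next_step_inputs is hypothetical and works on plain lists; the real method operates on ModelInputs and device buffers.

```python
from typing import List, Tuple


def next_step_inputs(next_tokens: List[int]) -> Tuple[List[int], List[int]]:
    """Build (tokens, row_offsets) for a step with one new token per sequence."""
    tokens = list(next_tokens)
    row_offsets = list(range(len(tokens) + 1))
    return tokens, row_offsets


step_tokens, step_offsets = next_step_inputs([17, 42, 7])
print(step_tokens, step_offsets)
```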