Python module
max.pipelines.architectures.gemma3
Gemma 3 transformer architecture for text generation.
Gemma3Config
class max.pipelines.architectures.gemma3.Gemma3Config(*, vocab_size, hidden_size, intermediate_size, num_hidden_layers, num_attention_heads, num_key_value_heads, head_dim, hidden_activation, max_position_embeddings, rms_norm_eps, rope_theta, attention_bias, query_pre_attn_scalar, sliding_window, final_logit_softcapping, attn_logit_softcapping, rope_scaling, rope_local_base_freq, sliding_window_pattern, dtype, devices, interleaved_rope_weights, return_logits=ReturnLogits.LAST_TOKEN, kv_params, tie_word_embeddings=False, quant_config=None)
Bases: ArchConfigWithKVCache
Represents the MAX Engine configuration for Gemma 3 models.
Contains parameters specific to the Gemma 3 architecture (typically extracted from HuggingFace configs), plus MAX-specific runtime settings and helpers.
Parameters:
- vocab_size (int)
- hidden_size (int)
- intermediate_size (int)
- num_hidden_layers (int)
- num_attention_heads (int)
- num_key_value_heads (int)
- head_dim (int)
- hidden_activation (str)
- max_position_embeddings (int)
- rms_norm_eps (float)
- rope_theta (float)
- attention_bias (bool)
- query_pre_attn_scalar (float | None)
- sliding_window (int)
- final_logit_softcapping (float | None)
- attn_logit_softcapping (int | None)
- rope_scaling (LinearScalingParams | None)
- rope_local_base_freq (float)
- sliding_window_pattern (int)
- dtype (DType)
- devices (list[DeviceRef])
- interleaved_rope_weights (bool)
- return_logits (ReturnLogits)
- kv_params (KVCacheParams)
- tie_word_embeddings (bool)
- quant_config (QuantConfig | None)
attention_bias
attention_bias: bool
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attn_logit_softcapping
attn_logit_softcapping: int | None
Scaling factor when applying tanh softcapping on the attention scores.
calculate_max_seq_len()
static calculate_max_seq_len(pipeline_config, huggingface_config)
Calculates the maximum sequence length for the model.
Uses the max_length from the max.pipelines.config.PipelineConfig if provided,
otherwise falls back to the max_position_embeddings from the HuggingFace
configuration’s text config.
Parameters:
- pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
- huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).
Returns:
The calculated maximum sequence length.
Return type:
int
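The fallback behavior described above can be sketched as follows. `PipelineConfig` and the HuggingFace config are replaced here by minimal stand-in objects, so this only illustrates the selection logic, not the real classes:

```python
from types import SimpleNamespace

def resolve_max_seq_len(pipeline_config, huggingface_config):
    # Prefer an explicitly configured max_length; otherwise fall back to the
    # HuggingFace text config's max_position_embeddings.
    if pipeline_config.max_length is not None:
        return pipeline_config.max_length
    return huggingface_config.text_config.max_position_embeddings

hf = SimpleNamespace(text_config=SimpleNamespace(max_position_embeddings=131072))
print(resolve_max_seq_len(SimpleNamespace(max_length=None), hf))   # 131072
print(resolve_max_seq_len(SimpleNamespace(max_length=8192), hf))   # 8192
```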
construct_kv_params()
static construct_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)
Constructs the KV cache parameters from configuration objects.
Parameters:
- huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).
- pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
- devices (list[DeviceRef]) – The list of devices the model will run on.
- kv_cache_config (KVCacheConfig) – The MAX Engine KV cache configuration settings (max.pipelines.max_config.KVCacheConfig).
- cache_dtype (DType) – The desired data type for the KV cache (max.dtype.DType).
Returns:
The configured max.pipelines.kv_cache.KVCacheParams object.
Return type:
KVCacheParams
devices
devices: list[DeviceRef]
Devices to run the model with.
dtype
dtype: DType
DType of the model weights and input.
final_logit_softcapping
final_logit_softcapping: float | None
Scaling factor when applying tanh softcapping on the logits.
finalize()
finalize(huggingface_config, state_dict, return_logits, quant_config=None)
Define parameters that can’t be determined just from the pipeline config.
This method sets fields that require introspection of the model weights (state_dict), such as tie_word_embeddings and quant_config.
Parameters:
- huggingface_config (AutoConfig) – The HuggingFace model configuration object.
- state_dict (dict[str, WeightData]) – The model’s state dictionary containing weights.
- return_logits (ReturnLogits) – Whether to return the last token, all tokens or a variable number of logits.
- quant_config (QuantConfig | None) – Scaled quantization configuration (optional).
Return type:
None
get_kv_params()
get_kv_params()
KV cache parameters to use when running the model.
Return type:
KVCacheParams
get_max_seq_len()
get_max_seq_len()
Returns the default maximum sequence length for the model.
Subclasses should determine whether this value can be overridden by
setting the --max-length (pipeline_config.model.max_length) flag.
Return type:
int
get_num_layers()
static get_num_layers(huggingface_config)
Retrieves the number of hidden layers from the HuggingFace configuration.
Parameters:
huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).
Returns:
The number of hidden layers specified in the configuration’s text config.
Return type:
int
head_dim
head_dim: int
The attention head dimension.
hidden_activation
hidden_activation: str
The non-linear activation function (function or string) used in the decoder. Defaults to “gelu_tanh” if not specified; “gelu_tanh” is a tanh approximation of the “gelu” activation function.
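The widely used tanh approximation of GELU can be written as:

```python
import math

def gelu_tanh(x: float) -> float:
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

print(gelu_tanh(0.0))  # 0.0
print(gelu_tanh(3.0))  # close to 3.0: large positive inputs pass through
```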
hidden_size
hidden_size: int
Dimension of the hidden representations.
initialize()
classmethod initialize(pipeline_config, model_config=None)
Initialize the config from a PipelineConfig.
Parameters:
- pipeline_config (PipelineConfig) – The pipeline configuration.
- model_config (MAXModelConfig | None) – The model configuration to read from. When None (the default), pipeline_config.model is used. Pass an explicit config (e.g. pipeline_config.draft_model) to initialize the arch config for a different model.
Return type:
Gemma3Config
initialize_from_config()
classmethod initialize_from_config(pipeline_config, huggingface_config)
Initializes a Gemma3Config instance from pipeline and HuggingFace configuration.
This method creates a config instance with all fields that can be determined from the pipeline and HuggingFace configuration, without needing the state_dict. Fields that depend on the state_dict (like tie_word_embeddings, quant_config) should be set via the finalize() method.
Parameters:
- pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
- huggingface_config (AutoConfig) – The HuggingFace model configuration object.
Returns:
An initialized Gemma3Config instance.
Return type:
Gemma3Config
interleaved_rope_weights
interleaved_rope_weights: bool
True if the rope weights are in interleaved complex format.
intermediate_size
intermediate_size: int
Dimension of the MLP representations.
kv_params
kv_params: KVCacheParams
KV cache parameters.
max_position_embeddings
max_position_embeddings: int
The maximum sequence length that this model might ever be used with.
num_attention_heads
num_attention_heads: int
Number of attention heads for each attention layer in the Transformer decoder.
num_hidden_layers
num_hidden_layers: int
Number of hidden layers in the Transformer decoder.
num_key_value_heads
num_key_value_heads: int
Number of key_value heads that should be used to implement Grouped Query Attention.
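In grouped-query attention, each key/value head is shared by a contiguous group of num_attention_heads / num_key_value_heads query heads. A toy illustration of the grouping:

```python
def kv_head_for_query(query_head: int, num_attention_heads: int, num_key_value_heads: int) -> int:
    # Query heads are partitioned into equal-sized groups, one per KV head.
    group_size = num_attention_heads // num_key_value_heads
    return query_head // group_size

# With 8 query heads and 4 KV heads, every pair of query heads shares a KV head.
print([kv_head_for_query(q, 8, 4) for q in range(8)])  # [0, 0, 1, 1, 2, 2, 3, 3]
```

Setting num_key_value_heads equal to num_attention_heads recovers standard multi-head attention; setting it to 1 recovers multi-query attention.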
quant_config
quant_config: QuantConfig | None = None
Scaled quantization configuration.
query_pre_attn_scalar
query_pre_attn_scalar: float | None
Scaling factor used on the attention scores.
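In the Gemma family this value typically replaces head_dim in the usual 1/sqrt(d) softmax scaling. A hedged sketch of that interpretation (not necessarily the exact MAX implementation):

```python
def attention_scale(query_pre_attn_scalar, head_dim):
    # Attention scores are multiplied by scalar**-0.5; when the config value
    # is None, fall back to the conventional head_dim**-0.5.
    scalar = query_pre_attn_scalar if query_pre_attn_scalar is not None else head_dim
    return scalar ** -0.5

print(attention_scale(None, 256))    # 0.0625 (= 1/16)
print(attention_scale(1024.0, 256))  # 0.03125 (= 1/32)
```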
return_logits
return_logits: ReturnLogits = 'last_token'
Whether to return the last token, all logits, or a variable number of logits.
rms_norm_eps
rms_norm_eps: float
The epsilon used by the rms normalization layers.
rope_local_base_freq
rope_local_base_freq: float
The base period of the RoPE embeddings for local attention.
rope_scaling
rope_scaling: LinearScalingParams | None
Scaling configuration for the RoPE embeddings used in global attention.
rope_theta
rope_theta: float
The base period of the RoPE embeddings.
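In standard RoPE, the base period determines one rotation frequency per pair of head dimensions:

```python
def rope_inv_freqs(rope_theta: float, head_dim: int) -> list[float]:
    # Frequency for dimension pair i is theta^(-2i / head_dim); a larger theta
    # slows the low frequencies, which helps represent longer contexts.
    return [rope_theta ** (-2.0 * i / head_dim) for i in range(head_dim // 2)]

print(rope_inv_freqs(10000.0, 8))  # ~[1.0, 0.1, 0.01, 0.001]
```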
sliding_window
sliding_window: int
In the Gemma 3 language model, a subset of layers (determined by sliding_window_pattern) uses sliding window attention. This is the size of the sliding window.
sliding_window_pattern
sliding_window_pattern: int
Pattern for the sliding window attention.
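One common layout (following the HuggingFace Gemma 3 reference implementation; the exact indexing here is illustrative) is that every sliding_window_pattern-th layer attends globally while the rest use the sliding window:

```python
def uses_global_attention(layer_idx: int, sliding_window_pattern: int) -> bool:
    # With a pattern of 6, layers 0-4 use the local sliding window and
    # layer 5 (then 11, 17, ...) attends over the full context.
    return (layer_idx + 1) % sliding_window_pattern == 0

print([uses_global_attention(i, 6) for i in range(6)])
# [False, False, False, False, False, True]
```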
tie_word_embeddings
tie_word_embeddings: bool = False
Whether to tie weight embeddings. When true, the output linear layer uses the same weight as the embedding layer.
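With tied embeddings, the output projection computes logits as x @ W_embed^T, so no separate output matrix is stored. A toy example with vocab_size=3 and hidden_size=2:

```python
# Toy embedding matrix: one row per vocabulary entry (vocab=3, hidden=2).
embedding = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]

def tied_output_logits(hidden_state, embedding):
    # logits[v] is the dot product of the hidden state with embedding row v,
    # i.e. the embedding matrix is reused, transposed, as the output layer.
    return [sum(h * w for h, w in zip(hidden_state, row)) for row in embedding]

print(tied_output_logits([1.0, 1.0], embedding))  # ~[0.3, 0.7, 1.1]
```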
vocab_size
vocab_size: int
Vocabulary size of the Gemma3Text model.
Gemma3Model
class max.pipelines.architectures.gemma3.Gemma3Model(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN)
Bases: LogProbabilitiesMixin, AlwaysSignalBuffersMixin, PipelineModelWithKVCache[TextContext]
A Gemma 3 pipeline model for text generation.
This class integrates the Gemma 3 architecture with the MAX Engine pipeline infrastructure, handling model loading, KV cache management, and input preparation for inference.
Parameters:
- pipeline_config (PipelineConfig) – The configuration settings for the entire pipeline.
- session (InferenceSession) – The MAX Engine inference session managing the runtime.
- devices (list[Device]) – A list of MAX Engine devices (max.driver.Device) to run the model on.
- kv_cache_config (KVCacheConfig) – Configuration settings for the Key-Value cache (max.pipelines.max_config.KVCacheConfig).
- weights (Weights) – The model weights (max.graph.weights.Weights).
- adapter (WeightsAdapter | None) – An optional adapter to modify weights before loading (max.graph.weights.WeightsAdapter).
- return_logits (ReturnLogits) – The number of top logits to return from the model execution.
calculate_max_seq_len()
static calculate_max_seq_len(pipeline_config, huggingface_config)
Calculates the maximum sequence length for the Gemma 3 model.
Uses the max_length from the max.pipelines.config.PipelineConfig
if provided, otherwise falls back to the max_position_embeddings from
the HuggingFace configuration’s text config.
Parameters:
- pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
- huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).
Returns:
The calculated maximum sequence length.
Return type:
int
execute()
execute(model_inputs)
Executes the Gemma 3 model with the prepared inputs.
Parameters:
model_inputs (ModelInputs) – The prepared inputs for the model execution, typically including token IDs, attention masks/offsets, and KV cache inputs.
Returns:
An object containing the output logits from the model execution.
Return type:
ModelOutputs
get_kv_params()
classmethod get_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)
Gets the parameters required to configure the KV cache for Gemma 3.
Delegates to the Gemma3Config.construct_kv_params static method.
Parameters:
- huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).
- pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
- devices (list[DeviceRef]) – The list of devices the model will run on.
- kv_cache_config (KVCacheConfig) – The MAX Engine KV cache configuration settings (max.pipelines.max_config.KVCacheConfig).
- cache_dtype (DType) – The desired data type for the KV cache (max.dtype.DType).
Returns:
The configured max.pipelines.kv_cache.KVCacheParams object.
Return type:
KVCacheParams
get_num_layers()
classmethod get_num_layers(huggingface_config)
Gets the number of hidden layers from the HuggingFace configuration.
Delegates to the Gemma3Config.get_num_layers static method.
Parameters:
huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).
Returns:
The number of hidden layers.
Return type:
int
load_model()
load_model(session)
Loads the compiled Gemma 3 model into the MAX Engine session.
Parameters:
session (InferenceSession) – The MAX Engine inference session.
Returns:
The loaded MAX Engine model object.
Return type:
Model
model
model: Model
The compiled and initialized MAX Engine model ready for inference.
prepare_initial_token_inputs()
prepare_initial_token_inputs(replica_batches, kv_cache_inputs=None, return_n_logits=1)
Prepares the initial inputs for the first execution pass of the Gemma 3 model.
Parameters:
- replica_batches (Sequence[Sequence[TextContext]]) – A sequence of sequences of TextContext objects representing the input prompts for each replica.
- kv_cache_inputs (KVCacheInputs[Buffer, Buffer] | None) – Optional inputs required by the KV cache manager.
- return_n_logits (int)
Returns:
The prepared ModelInputs object for the initial execution step.
Return type:
ModelInputs
prepare_next_token_inputs()
prepare_next_token_inputs(next_tokens, prev_model_inputs)
Prepares the inputs for subsequent execution steps in a multi-step generation.
Parameters:
- next_tokens (Buffer) – The tensor containing the token IDs generated in the previous step.
- prev_model_inputs (ModelInputs) – The ModelInputs used in the previous execution step.
Returns:
The prepared ModelInputs object for the next execution step.
Return type:
ModelInputs