Python module
max.pipelines.architectures.gemma3_modulev3
Gemma 3 transformer architecture for text generation.
Gemma3Config
class max.pipelines.architectures.gemma3_modulev3.Gemma3Config(*, vocab_size, hidden_size, intermediate_size, num_hidden_layers, num_attention_heads, num_key_value_heads, head_dim, hidden_activation, max_position_embeddings, rms_norm_eps, rope_theta, attention_bias, query_pre_attn_scalar, sliding_window, final_logit_softcapping, attn_logit_softcapping, rope_scaling, rope_local_base_freq, sliding_window_pattern, dtype, devices, interleaved_rope_weights, return_logits=ReturnLogits.LAST_TOKEN, kv_params, tie_word_embeddings=False, quant_config=None)
Bases: ArchConfigWithKVCache
Represents the MAX Engine configuration for Gemma 3 models.
Contains parameters specific to the Gemma 3 architecture (typically extracted from HuggingFace configs), plus MAX-specific runtime settings and helpers.
Parameters:
- vocab_size (int)
- hidden_size (int)
- intermediate_size (int)
- num_hidden_layers (int)
- num_attention_heads (int)
- num_key_value_heads (int)
- head_dim (int)
- hidden_activation (str)
- max_position_embeddings (int)
- rms_norm_eps (float)
- rope_theta (float)
- attention_bias (bool)
- query_pre_attn_scalar (float | None)
- sliding_window (int)
- final_logit_softcapping (float | None)
- attn_logit_softcapping (int | None)
- rope_scaling (LinearScalingParams | None)
- rope_local_base_freq (float)
- sliding_window_pattern (int)
- dtype (DType)
- devices (list[DeviceRef])
- interleaved_rope_weights (bool)
- return_logits (ReturnLogits)
- kv_params (KVCacheParams)
- tie_word_embeddings (bool)
- quant_config (QuantConfig | None)
attention_bias
attention_bias: bool
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attn_logit_softcapping
attn_logit_softcapping: int | None
Scaling factor when applying tanh softcapping on the attention scores.
calculate_max_seq_len()
static calculate_max_seq_len(pipeline_config, huggingface_config)
Calculates the maximum sequence length for the model.
Uses the max_length from the max.pipelines.config.PipelineConfig if provided,
otherwise falls back to the max_position_embeddings from the HuggingFace
configuration’s text config.
Parameters:

- pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
- huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).

Returns:

The calculated maximum sequence length.

Return type:

int
construct_kv_params()
static construct_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)
Constructs the KV cache parameters from configuration objects.
Parameters:

- huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).
- pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
- devices (list[DeviceRef]) – The devices the model will run on.
- kv_cache_config (KVCacheConfig) – The MAX Engine KV cache configuration settings (max.pipelines.max_config.KVCacheConfig).
- cache_dtype (DType) – The desired data type for the KV cache (max.dtype.DType).

Returns:

The configured max.pipelines.kv_cache.KVCacheParams object.

Return type:

KVCacheParams
devices
devices: list[DeviceRef]
Devices to run the model with.
dtype
dtype: DType
DType of the model weights and input.
final_logit_softcapping
final_logit_softcapping: float | None
Scaling factor when applying tanh softcapping on the logits.
finalize()
finalize(huggingface_config, state_dict, return_logits, quant_config=None)
Define parameters that can’t be determined just from the pipeline config.
This method sets fields that require introspection of the model weights (state_dict), such as tie_word_embeddings and quant_config.
Parameters:

- huggingface_config (AutoConfig) – The HuggingFace model configuration object.
- state_dict (dict[str, WeightData]) – The model’s state dictionary containing weights.
- return_logits (ReturnLogits) – Whether to return the last token, all tokens, or a variable number of logits.
- quant_config (QuantConfig | None) – Scaled quantization configuration (optional).

Return type:

None
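One plausible way weight introspection could determine tie_word_embeddings is by checking whether the checkpoint ships a separate output-projection weight. This is a hypothetical sketch: the key name "lm_head.weight" is an assumption for illustration, and the real check in finalize() may differ:

```python
def infer_tie_word_embeddings(state_dict: dict) -> bool:
    # If the checkpoint has no separate output-projection weight
    # (hypothetical key name), the output layer must reuse the embedding
    # matrix, i.e. the embeddings are tied.
    return "lm_head.weight" not in state_dict
```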
get_kv_params()
get_kv_params()
KV cache parameters to use when running the model.
Return type:

KVCacheParams
get_max_seq_len()
get_max_seq_len()
Returns the default maximum sequence length for the model.
Subclasses should determine whether this value can be overridden by
setting the --max-length (pipeline_config.model.max_length) flag.
Return type:

int
get_num_layers()
static get_num_layers(huggingface_config)
Retrieves the number of hidden layers from the HuggingFace configuration.
Parameters:

huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).

Returns:

The number of hidden layers specified in the configuration’s text config.

Return type:

int
head_dim
head_dim: int
The attention head dimension.
hidden_activation
hidden_activation: str
The non-linear activation function (function or string) in the decoder. Will default to “gelu_tanh” if not specified. “gelu_tanh” uses an approximation of the “gelu” activation function.
hidden_size
hidden_size: int
Dimension of the hidden representations.
initialize()
classmethod initialize(pipeline_config, model_config=None)
Initialize the config from a PipelineConfig.
Parameters:

- pipeline_config (PipelineConfig) – The pipeline configuration.
- model_config (MAXModelConfig | None) – The model configuration to read from. When None (the default), pipeline_config.model is used. Pass an explicit config (e.g. pipeline_config.draft_model) to initialize the arch config for a different model.

Return type:
initialize_from_config()
classmethod initialize_from_config(pipeline_config, huggingface_config)
Initializes a Gemma3Config instance from pipeline and HuggingFace configuration.
This method creates a config instance with all fields that can be determined from the pipeline and HuggingFace configuration, without needing the state_dict. Fields that depend on the state_dict (like tie_word_embeddings, quant_config) should be set via the finalize() method.
Parameters:

- pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
- huggingface_config (AutoConfig) – The HuggingFace model configuration object.

Returns:

An initialized Gemma3Config instance.

Return type:

Gemma3Config
interleaved_rope_weights
interleaved_rope_weights: bool
True if the rope weights are in interleaved complex format.
intermediate_size
intermediate_size: int
Dimension of the MLP representations.
kv_params
kv_params: KVCacheParams
KV cache parameters.
max_position_embeddings
max_position_embeddings: int
The maximum sequence length that this model might ever be used with.
num_attention_heads
num_attention_heads: int
Number of attention heads for each attention layer in the Transformer decoder.
num_hidden_layers
num_hidden_layers: int
Number of hidden layers in the Transformer decoder.
num_key_value_heads
num_key_value_heads: int
Number of key_value heads that should be used to implement Grouped Query Attention.
quant_config
quant_config: QuantConfig | None = None
Scaled quantization configuration.
query_pre_attn_scalar
query_pre_attn_scalar: float | None
Scaling factor used on the attention scores.
return_logits
return_logits: ReturnLogits = 'last_token'
Whether to return the last token, all logits, or a variable number of logits.
rms_norm_eps
rms_norm_eps: float
The epsilon used by the rms normalization layers.
rope_local_base_freq
rope_local_base_freq: float
The base period of the RoPE embeddings for local attention.
rope_scaling
rope_scaling: LinearScalingParams | None
Scaling configuration for the RoPE embeddings used in global attention.
rope_theta
rope_theta: float
The base period of the RoPE embeddings.
sliding_window
sliding_window: int
In the Gemma3 language model, every other layer uses sliding window attention. This is the size of the sliding window.
sliding_window_pattern
sliding_window_pattern: int
Pattern for the sliding window attention.
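A common interpretation of such a pattern (used, for instance, by HuggingFace's Gemma 3 implementation) is that every sliding_window_pattern-th layer uses global attention and the rest use local sliding-window attention. Whether MAX follows exactly this convention is an assumption; the sketch is illustrative:

```python
def uses_sliding_window(layer_idx: int, pattern: int) -> bool:
    # With pattern=6, layers 0-4 use local sliding-window attention and
    # layer 5 (and 11, 17, ...) uses global attention.
    return (layer_idx + 1) % pattern != 0
```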
tie_word_embeddings
tie_word_embeddings: bool = False
Whether to tie weight embeddings. When true, the output linear layer uses the same weight as the embedding layer.
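With tied weights, the output projection is just the transpose of the embedding table, so logits are dot products against embedding rows. A minimal plain-Python sketch with toy numbers:

```python
# Toy embedding table: vocab_size=3, hidden_size=2.
embedding = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]

def tied_logits(hidden: list[float], table: list[list[float]]) -> list[float]:
    # logits[v] is the dot product of the hidden state with token v's
    # embedding row -- no separate output weight matrix is stored.
    return [sum(h * w for h, w in zip(hidden, row)) for row in table]
```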
vocab_size
vocab_size: int
Vocabulary size of the Gemma3Text model.
Gemma3Model
class max.pipelines.architectures.gemma3_modulev3.Gemma3Model(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN)
Bases: LogProbabilitiesMixin, PipelineModelWithKVCache[TextContext]
A Gemma3 pipeline model for text generation using the ModuleV3 API.
This class integrates the Gemma3 architecture with the MAX Engine pipeline infrastructure using the V3 eager compilation API.
Parameters:
- pipeline_config (PipelineConfig)
- session (InferenceSession)
- devices (list[Device])
- kv_cache_config (KVCacheConfig)
- weights (Weights)
- adapter (WeightsAdapter | None)
- return_logits (ReturnLogits)
calculate_max_seq_len()
static calculate_max_seq_len(pipeline_config, huggingface_config)
Calculates the optimal max sequence length for the model.
Models are expected to implement this method. The following example shows how to implement it for a Mistral model:
class MistralModel(PipelineModel):
    @classmethod
    def calculate_max_seq_len(cls, pipeline_config, huggingface_config) -> int:
        try:
            return upper_bounded_default(
                upper_bound=huggingface_config.max_seq_len,
                default=pipeline_config.model.max_length,
            )
        except ValueError as e:
            raise ValueError(
                "Unable to infer max_length for Mistral, the provided "
                f"max_length ({pipeline_config.model.max_length}) exceeds the "
                f"model's max_seq_len ({huggingface_config.max_seq_len})."
            ) from e
Parameters:

- pipeline_config (PipelineConfig) – Configuration for the pipeline.
- huggingface_config (AutoConfig) – Hugging Face model configuration.

Returns:

The maximum sequence length to use.

Return type:

int
execute()
execute(model_inputs)
Executes the Gemma3 model with the prepared inputs.
Parameters:

model_inputs (ModelInputs)

Return type:
get_kv_params()
classmethod get_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)
Returns the KV cache params for the pipeline model.
Parameters:

- huggingface_config (AutoConfig)
- pipeline_config (PipelineConfig)
- devices (list[DeviceRef])
- kv_cache_config (KVCacheConfig)
- cache_dtype (DType)

Return type:

KVCacheParams
get_num_layers()
classmethod get_num_layers(huggingface_config)
Parameters:

huggingface_config (AutoConfig)

Return type:

int
load_model()
load_model()
Loads the compiled Gemma3 model using the ModuleV3 API.
prepare_initial_token_inputs()
prepare_initial_token_inputs(replica_batches, kv_cache_inputs=None, return_n_logits=1)
Prepares the initial inputs to be passed to execute().
The inputs and functionality can vary per model. For example, model
inputs could include encoded tensors, unique IDs per tensor when using
a KV cache manager, and kv_cache_inputs (or None if the model does
not use KV cache). This method typically batches encoded tensors,
claims a KV cache slot if needed, and returns the inputs and caches.
Parameters:

- replica_batches (Sequence[Sequence[TextContext]])
- kv_cache_inputs (KVCacheInputs[Buffer, Buffer] | None)
- return_n_logits (int)

Return type:
prepare_next_token_inputs()
prepare_next_token_inputs(next_tokens, prev_model_inputs)
Prepares the secondary inputs to be passed to execute().
While prepare_initial_token_inputs manages the initial inputs, this function updates those inputs for each step in a multi-step execution pattern.
Parameters:

- next_tokens (Buffer)
- prev_model_inputs (ModelInputs)

Return type: