Python module
max.pipelines.architectures.gemma3
Gemma 3 transformer architecture for text generation.
Gemma3Config
class max.pipelines.architectures.gemma3.Gemma3Config(*, vocab_size, hidden_size, intermediate_size, num_hidden_layers, num_attention_heads, num_key_value_heads, head_dim, hidden_activation, max_position_embeddings, rms_norm_eps, rope_theta, attention_bias, query_pre_attn_scalar, sliding_window, final_logit_softcapping, attn_logit_softcapping, rope_scaling, rope_local_base_freq, sliding_window_pattern, dtype, devices, interleaved_rope_weights, return_logits=ReturnLogits.LAST_TOKEN, kv_params, tie_word_embeddings=False, quant_config=None)
Bases: ArchConfigWithKVCache
Represents the MAX Engine configuration for Gemma 3 models.
Contains parameters specific to the Gemma 3 architecture (typically extracted from HuggingFace configs), plus MAX-specific runtime settings and helpers.
-
Parameters:
-
- vocab_size (int)
- hidden_size (int)
- intermediate_size (int)
- num_hidden_layers (int)
- num_attention_heads (int)
- num_key_value_heads (int)
- head_dim (int)
- hidden_activation (str)
- max_position_embeddings (int)
- rms_norm_eps (float)
- rope_theta (float)
- attention_bias (bool)
- query_pre_attn_scalar (float | None)
- sliding_window (int)
- final_logit_softcapping (float | None)
- attn_logit_softcapping (int | None)
- rope_scaling (LinearScalingParams | None)
- rope_local_base_freq (float)
- sliding_window_pattern (int)
- dtype (DType)
- devices (list[DeviceRef])
- interleaved_rope_weights (bool)
- return_logits (ReturnLogits)
- kv_params (KVCacheParams)
- tie_word_embeddings (bool)
- quant_config (QuantConfig | None)
attention_bias
attention_bias: bool
Whether to use a bias in the query, key, value, and output projection layers during self-attention.
attn_logit_softcapping
attn_logit_softcapping: int | None
Scaling factor when applying tanh softcapping on the attention scores.
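Tanh softcapping (used for both attn_logit_softcapping and final_logit_softcapping below) squashes values smoothly into a bounded range. A minimal illustrative sketch, not the MAX kernel itself:

```python
import math

def softcap(score: float, cap: float) -> float:
    # Squash `score` smoothly into the interval [-cap, cap]; small
    # scores pass through nearly unchanged, large ones saturate.
    return cap * math.tanh(score / cap)
```

Because tanh is approximately the identity near zero, softcapping leaves typical scores almost untouched while preventing outliers from dominating the softmax.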
calculate_max_seq_len()
static calculate_max_seq_len(pipeline_config, huggingface_config)
Calculates the maximum sequence length for the model.
Uses the max_length from the max.pipelines.config.PipelineConfig if provided,
otherwise falls back to the max_position_embeddings from the HuggingFace
configuration's text config.
-
Parameters:
-
- pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
- huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).
-
Returns:
-
The calculated maximum sequence length.
-
Return type:
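The fallback described above can be sketched as plain Python. The stand-in function and arguments here are hypothetical simplifications; the real method reads these values from PipelineConfig and transformers.AutoConfig objects:

```python
def resolve_max_seq_len(pipeline_max_length, hf_max_position_embeddings):
    # Prefer an explicitly configured max_length; otherwise fall back
    # to the HuggingFace text config's max_position_embeddings.
    if pipeline_max_length is not None:
        return pipeline_max_length
    return hf_max_position_embeddings
```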
construct_kv_params()
static construct_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)
Constructs the KV cache parameters from configuration objects.
-
Parameters:
-
- huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).
- pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
- devices (list[DeviceRef]) – The devices the model will run on.
- kv_cache_config (KVCacheConfig) – The MAX Engine KV cache configuration settings (max.pipelines.max_config.KVCacheConfig).
- cache_dtype (DType) – The desired data type for the KV cache (max.dtype.DType).
-
Returns:
-
The configured max.pipelines.kv_cache.KVCacheParams object.
Return type:
devices
devices: list[DeviceRef]
Devices to run the model with.
dtype
dtype: DType
DType of the model weights and input.
final_logit_softcapping
final_logit_softcapping: float | None
Scaling factor when applying tanh softcapping on the logits.
finalize()
finalize(huggingface_config, state_dict, return_logits, quant_config=None)
Define parameters that can't be determined just from the pipeline config.
This method sets fields that require introspection of the model weights (state_dict), such as tie_word_embeddings and quant_config.
-
Parameters:
-
- huggingface_config (AutoConfig) – The HuggingFace model configuration object.
- state_dict (dict[str, WeightData]) – The model's state dictionary containing weights.
- return_logits (ReturnLogits) – Whether to return the last token, all tokens, or a variable number of logits.
- quant_config (QuantConfig | None) – Scaled quantization configuration (optional).
-
Return type:
-
None
get_kv_params()
get_kv_params()
KV cache parameters to use when running the model.
-
Return type:
get_max_seq_len()
get_max_seq_len()
Returns the default maximum sequence length for the model.
Subclasses should determine whether this value can be overridden by
setting the --max-length (pipeline_config.model.max_length) flag.
-
Return type:
get_num_layers()
static get_num_layers(huggingface_config)
Retrieves the number of hidden layers from the HuggingFace configuration.
-
Parameters:
-
huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).
-
Returns:
-
The number of hidden layers specified in the configuration's text config.
-
Return type:
head_dim
head_dim: int
The attention head dimension.
hidden_activation
hidden_activation: str
The non-linear activation function (function or string) in the decoder. Defaults to "gelu_tanh" if not specified; "gelu_tanh" uses an approximation of the "gelu" activation function.
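The "gelu_tanh" approximation referred to above is the widely used tanh-based approximation of GELU. An illustrative sketch (not the MAX kernel):

```python
import math

def gelu_tanh(x: float) -> float:
    # Tanh approximation of GELU:
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))
```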
hidden_size
hidden_size: int
Dimension of the hidden representations.
initialize()
classmethod initialize(pipeline_config, model_config=None)
Initialize the config from a PipelineConfig.
-
Parameters:
-
- pipeline_config (PipelineConfig) – The pipeline configuration.
- model_config (MAXModelConfig | None) – The model configuration to read from. When None (the default), pipeline_config.model is used. Pass an explicit config (e.g. pipeline_config.draft_model) to initialize the arch config for a different model.
-
Return type:
initialize_from_config()
classmethod initialize_from_config(pipeline_config, huggingface_config)
Initializes a Gemma3Config instance from pipeline and HuggingFace configuration.
This method creates a config instance with all fields that can be determined from the pipeline and HuggingFace configuration, without needing the state_dict. Fields that depend on the state_dict (like tie_word_embeddings, quant_config) should be set via the finalize() method.
-
Parameters:
-
- pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
- huggingface_config (AutoConfig) – The HuggingFace model configuration object.
-
Returns:
-
An initialized Gemma3Config instance.
-
Return type:
interleaved_rope_weights
interleaved_rope_weights: bool
True if the rope weights are in interleaved complex format.
intermediate_size
intermediate_size: int
Dimension of the MLP representations.
kv_params
kv_params: KVCacheParams
KV cache parameters.
max_position_embeddings
max_position_embeddings: int
The maximum sequence length that this model might ever be used with.
num_attention_heads
num_attention_heads: int
Number of attention heads for each attention layer in the Transformer decoder.
num_hidden_layers
num_hidden_layers: int
Number of hidden layers in the Transformer decoder.
num_key_value_heads
num_key_value_heads: int
Number of key-value heads used to implement Grouped Query Attention.
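In Grouped Query Attention, consecutive query heads share a single key-value head. An illustrative sketch of the mapping (not MAX internals); the function name is hypothetical:

```python
def kv_head_for_query_head(q_head: int, num_attention_heads: int, num_key_value_heads: int) -> int:
    # Each group of (num_attention_heads // num_key_value_heads)
    # consecutive query heads shares one KV head.
    group_size = num_attention_heads // num_key_value_heads
    return q_head // group_size
```

For example, with 8 attention heads and 2 KV heads, query heads 0-3 map to KV head 0 and query heads 4-7 map to KV head 1.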
quant_config
quant_config: QuantConfig | None = None
Scaled quantization configuration.
query_pre_attn_scalar
query_pre_attn_scalar: float | None
Scaling factor used on the attention scores.
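In the Gemma family's reference implementations, queries are multiplied by query_pre_attn_scalar ** -0.5 before the attention dot product. A hedged sketch of that convention (not necessarily the exact MAX kernel):

```python
def scale_queries(queries, query_pre_attn_scalar):
    # Apply the pre-attention scaling factor to each query value.
    scale = query_pre_attn_scalar ** -0.5
    return [q * scale for q in queries]
```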
return_logits
return_logits: ReturnLogits = 'last_token'
Whether to return the last token, all logits, or a variable number of logits.
rms_norm_eps
rms_norm_eps: float
The epsilon used by the RMS normalization layers.
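A minimal sketch showing where rms_norm_eps enters RMS normalization (learned scale weights omitted for brevity; not the MAX implementation):

```python
import math

def rms_norm(xs, eps):
    # Divide each element by the root-mean-square of the vector;
    # eps guards against division by zero for near-zero inputs.
    rms = math.sqrt(sum(x * x for x in xs) / len(xs) + eps)
    return [x / rms for x in xs]
```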
rope_local_base_freq
rope_local_base_freq: float
The base period of the RoPE embeddings for local attention.
rope_scaling
rope_scaling: LinearScalingParams | None
Scaling configuration for the RoPE embeddings used in global attention.
rope_theta
rope_theta: float
The base period of the RoPE embeddings.
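The base period feeds the standard RoPE inverse-frequency schedule; rope_theta (global attention) or rope_local_base_freq (local attention) plays the role of `base` in this illustrative sketch:

```python
def rope_inv_freqs(head_dim: int, base: float):
    # Standard RoPE frequencies: base^(-2i / head_dim) for each of
    # the head_dim // 2 rotation pairs.
    return [base ** (-2.0 * i / head_dim) for i in range(head_dim // 2)]
```

A larger base stretches the lowest frequencies, which is why global and local attention layers can use different effective context behavior with different base periods.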
sliding_window
sliding_window: int
In the Gemma 3 language model, every other layer uses sliding window attention. This is the size of the sliding window.
sliding_window_pattern
sliding_window_pattern: int
Pattern for the sliding window attention.
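One plausible reading of such a pattern (a hedged sketch, not MAX internals): every `pattern`-th layer uses global attention and the remaining layers use sliding-window attention, so a pattern of 2 alternates layer types as described above.

```python
def uses_sliding_window(layer_idx: int, pattern: int) -> bool:
    # Layers whose 1-based index is a multiple of `pattern` use
    # global attention; all other layers use the sliding window.
    return (layer_idx + 1) % pattern != 0
```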
tie_word_embeddings
tie_word_embeddings: bool = False
Whether to tie weight embeddings. When true, the output linear layer uses the same weight as the embedding layer.
vocab_size
vocab_size: int
Vocabulary size of the Gemma3Text model.
Gemma3Inputs
class max.pipelines.architectures.gemma3.Gemma3Inputs(tokens, input_row_offsets, signal_buffers, return_n_logits, *, kv_cache_inputs=None, lora_ids=None, lora_ranks=None, hidden_states=None)
Bases: ModelInputs
A class representing inputs for the Gemma3 model.
This class encapsulates the input tensors required for the Gemma3 model execution.
-
Parameters:
input_row_offsets
input_row_offsets: Buffer
Tensor containing the offsets for each row in the ragged input sequence.
return_n_logits
return_n_logits: Buffer
Number of logits to return.
signal_buffers
Device buffers used for synchronization in communication collectives.
tokens
tokens: Buffer
Tensor containing the input token IDs.
Gemma3Model
class max.pipelines.architectures.gemma3.Gemma3Model(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN)
Bases: LogProbabilitiesMixin, AlwaysSignalBuffersMixin, PipelineModelWithKVCache[TextContext]
A Gemma 3 pipeline model for text generation.
This class integrates the Gemma 3 architecture with the MAX Engine pipeline infrastructure, handling model loading, KV cache management, and input preparation for inference.
-
Parameters:
-
- pipeline_config (PipelineConfig) β The configuration settings for the entire pipeline.
- session (InferenceSession) β The MAX Engine inference session managing the runtime.
- devices (list[Device]) – A list of MAX Engine devices (max.driver.Device) to run the model on.
- kv_cache_config (KVCacheConfig) – Configuration settings for the Key-Value cache (max.pipelines.max_config.KVCacheConfig).
- weights (Weights) – The model weights (max.graph.weights.Weights).
- adapter (WeightsAdapter | None) – An optional adapter to modify weights before loading (max.graph.weights.WeightsAdapter).
- return_logits (ReturnLogits) – The number of top logits to return from the model execution.
calculate_max_seq_len()
static calculate_max_seq_len(pipeline_config, huggingface_config)
Calculates the maximum sequence length for the Gemma 3 model.
Uses the max_length from the max.pipelines.config.PipelineConfig
if provided, otherwise falls back to the max_position_embeddings from
the HuggingFace configuration's text config.
-
Parameters:
-
- pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
- huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).
-
Returns:
-
The calculated maximum sequence length.
-
Return type:
execute()
execute(model_inputs)
Executes the Gemma 3 model with the prepared inputs.
-
Parameters:
-
model_inputs (ModelInputs) – The prepared inputs for the model execution, typically including token IDs, attention masks/offsets, and KV cache inputs.
-
Returns:
-
An object containing the output logits from the model execution.
-
Return type:
get_kv_params()
classmethod get_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)
Gets the parameters required to configure the KV cache for Gemma 3.
Delegates to the Gemma3Config.construct_kv_params static method.
-
Parameters:
-
- huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).
- pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
- devices (list[DeviceRef]) – The list of devices the model will run on.
- kv_cache_config (KVCacheConfig) – The MAX Engine KV cache configuration settings (max.pipelines.max_config.KVCacheConfig).
- cache_dtype (DType) – The desired data type for the KV cache (max.dtype.DType).
-
Returns:
-
The configured max.pipelines.kv_cache.KVCacheParams object.
-
Return type:
get_num_layers()
classmethod get_num_layers(huggingface_config)
Gets the number of hidden layers from the HuggingFace configuration.
Delegates to the Gemma3Config.get_num_layers static method.
-
Parameters:
-
huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).
-
Returns:
-
The number of hidden layers.
-
Return type:
load_model()
load_model(session)
Loads the compiled Gemma 3 model into the MAX Engine session.
-
Parameters:
-
session (InferenceSession) – The MAX Engine inference session.
-
Returns:
-
The loaded MAX Engine model object.
-
Return type:
model
model: Model
The compiled and initialized MAX Engine model ready for inference.
prepare_initial_token_inputs()
prepare_initial_token_inputs(replica_batches, kv_cache_inputs=None, return_n_logits=1)
Prepares the initial inputs for the first execution pass of the Gemma 3 model.
-
Parameters:
-
- replica_batches (Sequence[Sequence[TextContext]]) – A sequence of sequences of TextContext objects representing the input prompts for each replica.
- kv_cache_inputs (KVCacheInputs[Buffer, Buffer] | None) – Optional inputs required by the KV cache manager.
- return_n_logits (int)
-
Returns:
-
The prepared ModelInputs object for the initial execution step.
-
Return type:
prepare_next_token_inputs()
prepare_next_token_inputs(next_tokens, prev_model_inputs)
Prepares the inputs for subsequent execution steps in a multi-step generation.
-
Parameters:
-
- next_tokens (Buffer) – The tensor containing the token IDs generated in the previous step.
- prev_model_inputs (ModelInputs) – The ModelInputs used in the previous execution step.
-
Returns:
-
The prepared ModelInputs object for the next execution step.
-
Return type: