Python module
max.pipelines.architectures.gemma4_assistant
Gemma4AssistantConfig
class max.pipelines.architectures.gemma4_assistant.Gemma4AssistantConfig(backbone_hidden_size=5376, hidden_size=1024, num_hidden_layers=4, num_attention_heads=32, num_key_value_heads=16, num_global_key_value_heads=4, head_dim=256, global_head_dim=512, intermediate_size=8192, vocab_size=262144, rms_norm_eps=1e-06, hidden_activation='gelu_pytorch_tanh', layer_types=<factory>, sliding_window=1024, sliding_window_rope_theta=10000.0, global_rope_theta=1000000.0, global_rope_scaling=None, attention_k_eq_v=True, num_kv_shared_layers=4, max_position_embeddings=262144, devices=<factory>, dtype=bfloat16, return_logits=ReturnLogits.LAST_TOKEN, return_hidden_states=ReturnHiddenStates.NONE)
Bases: object
Configuration for the Gemma4 Assistant draft model.
The assistant model is a lightweight decoder that performs cross-attention against the target (backbone) model's KV cache. It has no K/V projection weights of its own.
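The cross-attention described above can be sketched in plain Python for one head and one query position. This is a minimal illustration, not the MAX kernel: the function name `cross_attend`, the list-based cache layout, and the 1/sqrt(head_dim) score scaling are assumptions. What it shows is the division of labor: the assistant supplies only the query, while keys and values are read directly from the target's cache.

```python
import math


def cross_attend(query, cached_keys, cached_values):
    """Attend one assistant query vector over K/V read from the target's cache.

    query: list[float] of length head_dim, from the assistant's Q projection.
    cached_keys, cached_values: one vector per cached position, taken from the
    target model; the assistant computes no keys or values of its own.
    (Hypothetical helper for illustration only.)
    """
    scale = 1.0 / math.sqrt(len(query))  # assumed scaling convention
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in cached_keys]
    # Numerically stable softmax over the cached positions.
    peak = max(scores)
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    weights = [w / total for w in weights]
    # Output is the attention-weighted sum of the cached value vectors.
    dim = len(cached_values[0])
    return [sum(w * v[d] for w, v in zip(weights, cached_values)) for d in range(dim)]
```

Because the assistant reuses the target's cached K/V, drafting a token costs only the assistant's query projection and its small decoder stack, not a second cache.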
Parameters:
- backbone_hidden_size (int)
- hidden_size (int)
- num_hidden_layers (int)
- num_attention_heads (int)
- num_key_value_heads (int)
- num_global_key_value_heads (int)
- head_dim (int)
- global_head_dim (int)
- intermediate_size (int)
- vocab_size (int)
- rms_norm_eps (float)
- hidden_activation (str)
- layer_types (list[str])
- sliding_window (int)
- sliding_window_rope_theta (float)
- global_rope_theta (float)
- global_rope_scaling (ProportionalScalingParams | None)
- attention_k_eq_v (bool)
- num_kv_shared_layers (int)
- max_position_embeddings (int)
- devices (list[DeviceRef])
- dtype (DType)
- return_logits (ReturnLogits)
- return_hidden_states (ReturnHiddenStates)
attention_k_eq_v
attention_k_eq_v: bool = True
Whether K and V projections are shared in the target model.
backbone_hidden_size
backbone_hidden_size: int = 5376
Hidden dimension of the target (backbone) model.
devices
devices: list[DeviceRef]
Devices on which to place weights and run computation.
dtype
dtype: DType = bfloat16
Data type for model weights and activations.
get_max_seq_len()
get_max_seq_len()
Return type:
global_head_dim
global_head_dim: int = 512
Per-head dimension for global attention.
global_rope_scaling
global_rope_scaling: ProportionalScalingParams | None = None
Proportional scaling config for global RoPE.
global_rope_theta
global_rope_theta: float = 1000000.0
RoPE theta for global attention.
head_dim
head_dim: int = 256
Per-head dimension for sliding window attention.
hidden_activation
hidden_activation: str = 'gelu_pytorch_tanh'
Activation function for the MLP.
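The name 'gelu_pytorch_tanh' conventionally refers to the tanh approximation of GELU (PyTorch's `approximate="tanh"` mode). The sketch below writes that formula out for a scalar; it illustrates the math only and is not the MAX kernel.

```python
import math


def gelu_pytorch_tanh(x: float) -> float:
    # Tanh approximation of GELU:
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))
```

For large negative inputs the output goes to 0, and for large positive inputs it approaches x, which is the same shape as exact (erf-based) GELU.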
hidden_size
hidden_size: int = 1024
Hidden dimension of the assistant model.
initialize()
classmethod initialize(pipeline_config, model_config=None)
Parameters:
- pipeline_config (PipelineConfig)
- model_config (MAXModelConfig | None)
Return type:
intermediate_size
intermediate_size: int = 8192
Feed-forward intermediate dimension.
layer_types
layer_types: list[str]
Per-layer attention type specification.
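Several fields on this page come in sliding-window and global pairs (head_dim / global_head_dim, num_key_value_heads / num_global_key_value_heads, sliding_window_rope_theta / global_rope_theta). The sketch below shows how a per-layer type list could select among those documented defaults. The label strings "sliding_attention" and "full_attention", the four-layer pattern, and the helper names are hypothetical; this page does not list the values layer_types accepts.

```python
from dataclasses import dataclass
from typing import Optional

# Hypothetical labels: the strings that layer_types actually accepts are not listed here.
SLIDING = "sliding_attention"
GLOBAL = "full_attention"


@dataclass(frozen=True)
class AttentionParams:
    head_dim: int
    num_kv_heads: int  # KV heads in the target layer being read
    rope_theta: float
    window: Optional[int]  # None: no sliding-window limit (assumed for global layers)


def attention_params(layer_type: str) -> AttentionParams:
    # Documented Gemma4AssistantConfig defaults, split by attention type.
    if layer_type == SLIDING:
        return AttentionParams(head_dim=256, num_kv_heads=16, rope_theta=10000.0, window=1024)
    if layer_type == GLOBAL:
        return AttentionParams(head_dim=512, num_kv_heads=4, rope_theta=1000000.0, window=None)
    raise ValueError(f"unknown layer type: {layer_type!r}")


# Hypothetical four-layer pattern, sized to match num_hidden_layers=4.
layer_types = [SLIDING, SLIDING, SLIDING, GLOBAL]
per_layer = [attention_params(t) for t in layer_types]
```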
max_position_embeddings
max_position_embeddings: int = 262144
Maximum sequence length supported by position embeddings.
num_attention_heads
num_attention_heads: int = 32
Number of query attention heads.
num_global_key_value_heads
num_global_key_value_heads: int = 4
Number of key/value heads for global attention in the target.
num_hidden_layers
num_hidden_layers: int = 4
Number of decoder layers in the assistant model.
num_key_value_heads
num_key_value_heads: int = 16
Number of key/value heads for sliding window attention in the target.
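With the defaults, the assistant has 32 query heads, while the target layers it reads from have 16 KV heads (sliding window) or 4 (global). One natural reading is grouped-query attention, where each KV head serves a fixed group of query heads. The sketch below computes that grouping under the assumption of contiguous groups (query head h uses KV head h // group_size); the helper names are hypothetical, and this page does not confirm the mapping MAX uses.

```python
def query_heads_per_kv_head(num_attention_heads: int, num_kv_heads: int) -> int:
    # Grouped-query attention requires query heads to split evenly across KV heads.
    if num_attention_heads % num_kv_heads != 0:
        raise ValueError("num_attention_heads must be divisible by num_kv_heads")
    return num_attention_heads // num_kv_heads


def kv_head_for_query_head(q_head: int, num_attention_heads: int, num_kv_heads: int) -> int:
    # Assumed contiguous grouping: heads 0..g-1 share KV head 0, and so on.
    return q_head // query_heads_per_kv_head(num_attention_heads, num_kv_heads)


sliding_group = query_heads_per_kv_head(32, 16)  # query heads per KV head, sliding layers
global_group = query_heads_per_kv_head(32, 4)    # query heads per KV head, global layers
```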
num_kv_shared_layers
num_kv_shared_layers: int = 4
Number of KV-shared layers.
return_hidden_states
return_hidden_states: ReturnHiddenStates = ReturnHiddenStates.NONE
Which hidden states to return from the model.
return_logits
return_logits: ReturnLogits = ReturnLogits.LAST_TOKEN
Which logits to return from the model.
rms_norm_eps
rms_norm_eps: float = 1e-06
Epsilon for RMS normalization.
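The epsilon enters RMS normalization inside the square root, so it only matters when the mean square of the activations is near zero. The scalar sketch below is the textbook formula; the helper name is hypothetical, and the weight convention (plain `weight` rather than, for example, `1 + weight`) is an assumption, not something this page states.

```python
import math


def rms_norm(x, weight, eps=1e-06):
    # Normalize by the root-mean-square of x; eps (rms_norm_eps) keeps the
    # denominator nonzero when x is all zeros.
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    # Apply the learned per-channel scale (plain `weight` convention assumed).
    return [v * inv_rms * w for v, w in zip(x, weight)]
```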
sliding_window
sliding_window: int = 1024
Sliding window size for local attention layers.
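A sliding-window layer restricts each query to recent positions instead of the whole prefix. The sketch below builds a boolean causal mask in which position i may attend to position j only if j <= i and i - j < window. The boundary convention (whether the window counts the current token) is an assumption, and the helper is hypothetical.

```python
def sliding_window_causal_mask(seq_len: int, window: int) -> list:
    # mask[i][j] is True when query position i may attend to key position j:
    # j is not in the future (j <= i) and lies within the last `window` positions.
    return [[(j <= i) and (i - j < window) for j in range(seq_len)] for i in range(seq_len)]


mask = sliding_window_causal_mask(seq_len=5, window=2)
```

With the default sliding_window=1024, each token in a local layer sees at most 1024 positions, so the attention cost per token stays bounded regardless of sequence length.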
sliding_window_rope_theta
sliding_window_rope_theta: float = 10000.0
RoPE theta for sliding window attention.
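The two RoPE thetas control how quickly position information rotates. In standard RoPE, rotation pair i of a head with dimension d turns at theta^(-2i/d) radians per position, so a larger theta yields slower rotations and longer wavelengths. The sketch below compares the documented sliding-window defaults (head_dim=256, theta=1e4) with the global ones (global_head_dim=512, theta=1e6). It models plain RoPE only and ignores global_rope_scaling; the helper names are hypothetical.

```python
import math


def rope_inv_freq(head_dim: int, theta: float) -> list:
    # Standard RoPE: pair i (of head_dim // 2) rotates at theta ** (-2i / head_dim) rad/position.
    return [theta ** (-(2 * i) / head_dim) for i in range(head_dim // 2)]


def longest_wavelength(head_dim: int, theta: float) -> float:
    # Positions needed for the slowest-rotating pair to complete one full turn.
    return 2 * math.pi / min(rope_inv_freq(head_dim, theta))


local = rope_inv_freq(256, 10000.0)
global_ = rope_inv_freq(512, 1000000.0)
```

The global layers' much longer wavelengths suit positions spread over long contexts, while the local layers only need to distinguish positions within their window.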
vocab_size
vocab_size: int = 262144
Vocabulary size.