Python module

max.pipelines.architectures.gemma4_assistant

Gemma4AssistantConfig

class max.pipelines.architectures.gemma4_assistant.Gemma4AssistantConfig(backbone_hidden_size=5376, hidden_size=1024, num_hidden_layers=4, num_attention_heads=32, num_key_value_heads=16, num_global_key_value_heads=4, head_dim=256, global_head_dim=512, intermediate_size=8192, vocab_size=262144, rms_norm_eps=1e-06, hidden_activation='gelu_pytorch_tanh', layer_types=<factory>, sliding_window=1024, sliding_window_rope_theta=10000.0, global_rope_theta=1000000.0, global_rope_scaling=None, attention_k_eq_v=True, num_kv_shared_layers=4, max_position_embeddings=262144, devices=<factory>, dtype=bfloat16, return_logits=ReturnLogits.LAST_TOKEN, return_hidden_states=ReturnHiddenStates.NONE)

Bases: object

Configuration for the Gemma4 Assistant draft model.

The assistant model is a lightweight decoder that performs cross-attention against the target (backbone) model’s KV cache. It has no K/V projection weights of its own.

Parameters:

  • backbone_hidden_size (int)
  • hidden_size (int)
  • num_hidden_layers (int)
  • num_attention_heads (int)
  • num_key_value_heads (int)
  • num_global_key_value_heads (int)
  • head_dim (int)
  • global_head_dim (int)
  • intermediate_size (int)
  • vocab_size (int)
  • rms_norm_eps (float)
  • hidden_activation (str)
  • layer_types (list[str])
  • sliding_window (int)
  • sliding_window_rope_theta (float)
  • global_rope_theta (float)
  • global_rope_scaling (ProportionalScalingParams | None)
  • attention_k_eq_v (bool)
  • num_kv_shared_layers (int)
  • max_position_embeddings (int)
  • devices (list[DeviceRef])
  • dtype (DType)
  • return_logits (ReturnLogits)
  • return_hidden_states (ReturnHiddenStates)
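
The defaults listed above imply a few derived widths. The following is a minimal sketch that uses a hypothetical stand-in dataclass (`AssistantDefaults`, which is not the real `Gemma4AssistantConfig`) to copy a handful of the documented defaults and multiply head counts by per-head dimensions. Whether the model uses these products in exactly this way is an assumption.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AssistantDefaults:
    """Hypothetical stand-in that copies a few documented defaults.

    This is NOT the real Gemma4AssistantConfig class.
    """

    num_attention_heads: int = 32
    num_key_value_heads: int = 16
    num_global_key_value_heads: int = 4
    head_dim: int = 256
    global_head_dim: int = 512


def derived_widths(cfg: AssistantDefaults) -> dict[str, int]:
    """Return (number of heads) * (per-head dimension) for each attention type."""
    return {
        "sliding_q_width": cfg.num_attention_heads * cfg.head_dim,
        "sliding_kv_width": cfg.num_key_value_heads * cfg.head_dim,
        "global_q_width": cfg.num_attention_heads * cfg.global_head_dim,
        "global_kv_width": cfg.num_global_key_value_heads * cfg.global_head_dim,
    }


widths = derived_widths(AssistantDefaults())
```

With the documented defaults, each query head shares a key/value head with one other query head in sliding window layers (32 / 16 = 2) and with seven others in global layers (32 / 4 = 8).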

attention_k_eq_v

attention_k_eq_v: bool = True

Whether K and V projections are shared in the target model.

backbone_hidden_size

backbone_hidden_size: int = 5376

Hidden dimension of the target (backbone) model.

devices

devices: list[DeviceRef]

Devices to place weights and run computation on.

dtype

dtype: DType = DType.bfloat16

Data type for model weights and activations.

get_max_seq_len()

get_max_seq_len()

Return type:

int

global_head_dim

global_head_dim: int = 512

Per-head dimension for global attention.

global_rope_scaling

global_rope_scaling: ProportionalScalingParams | None = None

Proportional scaling config for global RoPE.

global_rope_theta

global_rope_theta: float = 1000000.0

RoPE theta for global attention.
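
For intuition about what the theta value controls, here is a minimal sketch of the standard RoPE inverse-frequency formula, `theta ** (-2i / head_dim)`. The helper name `rope_inv_freq` is hypothetical. The sketch assumes rotation is applied across the full head dimension and does not model `global_rope_scaling`. How MAX computes these values internally is not described on this page.

```python
def rope_inv_freq(head_dim: int, theta: float) -> list[float]:
    """Standard RoPE inverse frequencies: theta ** (-2i / head_dim)."""
    return [theta ** (-(2 * i) / head_dim) for i in range(head_dim // 2)]


# Documented defaults: sliding window layers vs. global layers.
sliding_freqs = rope_inv_freq(head_dim=256, theta=10000.0)
global_freqs = rope_inv_freq(head_dim=512, theta=1000000.0)
```

A larger theta yields smaller minimum frequencies, so the slowest-rotating dimensions take longer to wrap around, which is the usual motivation for a higher theta on long-context (global) layers.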

head_dim

head_dim: int = 256

Per-head dimension for sliding window attention.

hidden_activation

hidden_activation: str = 'gelu_pytorch_tanh'

Activation function for the MLP.
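
The name `gelu_pytorch_tanh` is the label Hugging Face Transformers uses for the tanh approximation of GELU. A minimal scalar sketch of that formula follows. The function name `gelu_tanh` is hypothetical, and it is an assumption that MAX maps the string to this same formula.

```python
import math


def gelu_tanh(x: float) -> float:
    """Tanh approximation of GELU for a single scalar input."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)
    return 0.5 * x * (1.0 + math.tanh(inner))
```

For large positive inputs the function approaches the identity, and for large negative inputs it approaches zero.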

hidden_size

hidden_size: int = 1024

Hidden dimension of the assistant model.

initialize()

classmethod initialize(pipeline_config, model_config=None)

Parameters:

  • pipeline_config
  • model_config

Return type:

Self

intermediate_size

intermediate_size: int = 8192

Feed-forward intermediate dimension.

layer_types

layer_types: list[str]

Per-layer attention type specification.
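
As an illustration of how a per-layer type list could select between the sliding window and global defaults documented on this page, consider the sketch below. The label strings (`"sliding_attention"`, `"full_attention"`) are borrowed from Hugging Face Gemma configs and are an assumption, since this page does not list the accepted values. The helper `attention_settings` and the four-layer pattern are also hypothetical.

```python
# Assumed labels; the strings accepted by the real config are not listed here.
SLIDING = "sliding_attention"
GLOBAL = "full_attention"


def attention_settings(layer_type: str) -> dict:
    """Map a layer-type label to the documented per-type defaults."""
    if layer_type == SLIDING:
        return {"head_dim": 256, "rope_theta": 10000.0, "window": 1024}
    if layer_type == GLOBAL:
        return {"head_dim": 512, "rope_theta": 1000000.0, "window": None}
    raise ValueError(f"unknown layer type: {layer_type!r}")


# Hypothetical pattern for the default of four decoder layers.
layer_types = [SLIDING, SLIDING, SLIDING, GLOBAL]
per_layer = [attention_settings(t) for t in layer_types]
```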

max_position_embeddings

max_position_embeddings: int = 262144

Maximum sequence length supported by position embeddings.

num_attention_heads

num_attention_heads: int = 32

Number of query attention heads.

num_global_key_value_heads

num_global_key_value_heads: int = 4

Number of key/value heads for global attention in the target.

num_hidden_layers

num_hidden_layers: int = 4

Number of decoder layers in the assistant model.

num_key_value_heads

num_key_value_heads: int = 16

Number of key/value heads for sliding window attention in the target.

num_kv_shared_layers

num_kv_shared_layers: int = 4

Number of KV-shared layers.

return_hidden_states

return_hidden_states: ReturnHiddenStates = ReturnHiddenStates.NONE

Which hidden states to return from the model.

return_logits

return_logits: ReturnLogits = ReturnLogits.LAST_TOKEN

Which logits to return from the model.

rms_norm_eps

rms_norm_eps: float = 1e-06

Epsilon for RMS normalization.
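
To show where the epsilon enters, here is a minimal pure-Python sketch of the textbook RMS normalization formula. The function name `rms_norm` is hypothetical. Some Gemma-family implementations scale by `1 + weight` rather than `weight`, and this page does not say which variant is used, so the plain `weight` form below is an assumption.

```python
import math


def rms_norm(x: list[float], weight: list[float], eps: float = 1e-6) -> list[float]:
    """Scale x by the reciprocal of its root mean square, then by weight."""
    mean_sq = sum(v * v for v in x) / len(x)
    # eps keeps the denominator positive even when x is all zeros.
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [v * scale * w for v, w in zip(x, weight)]


normed = rms_norm([3.0, 4.0], [1.0, 1.0])
zeros = rms_norm([0.0, 0.0], [1.0, 1.0])
```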

sliding_window

sliding_window: int = 1024

Sliding window size for local attention layers.
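
A sliding window restricts each query position to the most recent keys. The sketch below builds such a causal mask for a toy sequence. The helper `sliding_window_mask` is hypothetical, and the boundary convention (the window counts the current token) is an assumption rather than something this page specifies.

```python
def sliding_window_mask(seq_len: int, window: int) -> list[list[bool]]:
    """Return mask[i][j], True when query i may attend to key j."""
    return [
        [0 <= i - j < window for j in range(seq_len)]
        for i in range(seq_len)
    ]


# Toy sizes for readability; the documented default window is 1024.
mask = sliding_window_mask(seq_len=6, window=3)
```

In this toy case the last query position attends only to the final three key positions.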

sliding_window_rope_theta

sliding_window_rope_theta: float = 10000.0

RoPE theta for sliding window attention.

vocab_size

vocab_size: int = 262144

Vocabulary size.