Python module
max.pipelines.architectures.llama4
Llama 4 (text-only) transformer architecture for text generation.
Llama4Config
class max.pipelines.architectures.llama4.Llama4Config(*, hidden_size, num_attention_heads, num_key_value_heads, num_hidden_layers, head_dim, rope_theta, rope_scaling_params, max_seq_len, intermediate_size, intermediate_size_mlp, vocab_size, dtype, model_quantization_encoding, quantization_config, kv_params, devices, num_local_experts, num_experts_per_tok, moe_layers, no_rope_layers, attention_chunk_size, use_qk_norm, attn_temperature_tuning, floor_scale, attn_scale, attention_bias, attention_multiplier, rms_norm_eps=1e-05, norm_dtype=None, tie_word_embeddings=False, quant_config=None, return_logits=ReturnLogits.LAST_TOKEN, return_hidden_states=ReturnHiddenStates.NONE, data_parallel_degree=1)
Bases: ArchConfigWithStoredKVParams, ArchConfigWithKVCache
Model configuration for Llama4 (text-only) graph construction.
Parameters:
- hidden_size (int)
- num_attention_heads (int)
- num_key_value_heads (int)
- num_hidden_layers (int)
- head_dim (int)
- rope_theta (float)
- rope_scaling_params (Llama3RopeScalingParams | None)
- max_seq_len (int)
- intermediate_size (int)
- intermediate_size_mlp (int)
- vocab_size (int)
- dtype (DType)
- model_quantization_encoding (QuantizationEncoding | None)
- quantization_config (QuantizationConfig | None)
- kv_params (KVCacheParams)
- devices (list[DeviceRef])
- num_local_experts (int)
- num_experts_per_tok (int)
- moe_layers (list[int])
- no_rope_layers (list[int])
- attention_chunk_size (int)
- use_qk_norm (bool)
- attn_temperature_tuning (bool)
- floor_scale (float)
- attn_scale (float)
- attention_bias (bool)
- attention_multiplier (float)
- rms_norm_eps (float)
- norm_dtype (DType | None)
- tie_word_embeddings (bool)
- quant_config (QuantConfig | None)
- return_logits (ReturnLogits)
- return_hidden_states (ReturnHiddenStates)
- data_parallel_degree (int)
attention_bias
attention_bias: bool
attention_chunk_size
attention_chunk_size: int
attention_multiplier
attention_multiplier: float
attn_scale
attn_scale: float
attn_temperature_tuning
attn_temperature_tuning: bool
calculate_max_seq_len()
classmethod calculate_max_seq_len(pipeline_config, huggingface_config, model_config=None)
Bounds max_length by the text config's max_position_embeddings.
The multimodal Hugging Face Llama4Config exposes max_position_embeddings only
under text_config, while the base implementation reads it from the
top-level config, so this override routes the lookup through get_text_config() first.
Parameters:
- pipeline_config (PipelineConfig)
- huggingface_config (AutoConfig)
- model_config (MAXModelConfig | None)
Return type:
int
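The override above amounts to reading max_position_embeddings from the text sub-config instead of the top-level config. The sketch below illustrates that lookup with hypothetical stand-in classes (`TextConfig`, `MultimodalConfig`) rather than the real transformers configs. The helper `bounded_max_seq_len` and its bounding policy (an unset length defaults to the limit, an oversized one raises ValueError) are assumptions for illustration, not MAX's implementation.

```python
from dataclasses import dataclass
from typing import Optional


# Hypothetical stand-ins for Hugging Face configs; not the transformers classes.
@dataclass
class TextConfig:
    max_position_embeddings: int

    def get_text_config(self) -> "TextConfig":
        # A text-only config is its own text config.
        return self


@dataclass
class MultimodalConfig:
    # Only the nested text config carries max_position_embeddings.
    text_config: TextConfig

    def get_text_config(self) -> TextConfig:
        return self.text_config


def bounded_max_seq_len(max_length: Optional[int], hf_config) -> int:
    """Bounds a requested max_length by the text config's position limit.

    Assumption: None means "use the limit", and a larger request is
    rejected with ValueError rather than silently clamped.
    """
    # Route through get_text_config() so nested (multimodal) configs work too.
    limit = hf_config.get_text_config().max_position_embeddings
    if max_length is None:
        return limit
    if max_length > limit:
        raise ValueError(
            f"max_length {max_length} exceeds max_position_embeddings {limit}"
        )
    return max_length
```

Reading `max_position_embeddings` directly on a `MultimodalConfig` instance would raise AttributeError, which is the failure the get_text_config() routing avoids.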
construct_kv_params()
classmethod construct_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)
Builds the default KV cache parameters for standard grouped-query attention.
Parameters:
- huggingface_config (AutoConfig)
- pipeline_config (PipelineConfig)
- devices (list[DeviceRef])
- kv_cache_config (KVCacheConfig)
- cache_dtype (DType)
Return type:
KVCacheParams
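In grouped-query attention, several query heads share one key/value head, so the KV cache grows with num_key_value_heads rather than num_attention_heads. The arithmetic below is a back-of-the-envelope sketch of that effect using the config's field names. Both functions are hypothetical; they ignore paging, data parallelism, and per-device sharding, and they do not reproduce what construct_kv_params computes.

```python
def gqa_group_size(num_attention_heads: int, num_key_value_heads: int) -> int:
    """Number of query heads that share each key/value head."""
    if num_attention_heads % num_key_value_heads != 0:
        raise ValueError(
            "num_attention_heads must be a multiple of num_key_value_heads"
        )
    return num_attention_heads // num_key_value_heads


def kv_cache_bytes_per_token(
    num_hidden_layers: int,
    num_key_value_heads: int,
    head_dim: int,
    bytes_per_element: int,
) -> int:
    """Rough bytes of K and V cached per token across all layers."""
    # Factor of 2: one cached tensor for keys and one for values.
    return 2 * num_hidden_layers * num_key_value_heads * head_dim * bytes_per_element
```

With illustrative values (not Llama 4's actual dimensions) of 32 query heads, 8 KV heads, 48 layers, head_dim 128, and 2-byte bfloat16 elements, `gqa_group_size` returns 4 and `kv_cache_bytes_per_token` returns 196608, a quarter of what one KV head per query head would need.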
data_parallel_degree
data_parallel_degree: int = 1
devices
devices: list[DeviceRef]
dtype
dtype: DType
finalize()
finalize(huggingface_config, state_dict, return_logits, return_hidden_states=ReturnHiddenStates.NONE)
Sets parameters that depend on the loaded state dict.
Parameters:
- huggingface_config (AutoConfig)
- state_dict (dict[str, WeightData])
- return_logits (ReturnLogits)
- return_hidden_states (ReturnHiddenStates)
Return type:
None
floor_scale
floor_scale: float
get_head_dim()
static get_head_dim(huggingface_config)
Returns the attention head size: head_dim if the config sets it, otherwise hidden_size // num_attention_heads.
Parameters:
huggingface_config (AutoConfig)
Return type:
int
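A minimal sketch of that fallback, using `types.SimpleNamespace` as a stand-in for a Hugging Face AutoConfig. The function mirrors the documented behavior but is not MAX's source; treating a missing or None head_dim as "not set" is an assumption.

```python
from types import SimpleNamespace


def get_head_dim(hf_config) -> int:
    """Returns head_dim when the config sets it, else hidden_size // num_attention_heads."""
    # Assumption: a missing attribute and an explicit None both mean "not set".
    head_dim = getattr(hf_config, "head_dim", None)
    if head_dim is not None:
        return head_dim
    return hf_config.hidden_size // hf_config.num_attention_heads


# An explicit head_dim wins even when it differs from the derived value.
explicit = SimpleNamespace(hidden_size=4096, num_attention_heads=32, head_dim=64)
derived = SimpleNamespace(hidden_size=4096, num_attention_heads=32)
print(get_head_dim(explicit))  # 64
print(get_head_dim(derived))   # 128
```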
get_num_layers()
static get_num_layers(huggingface_config)
Returns the number of layers in the decoder stack. Override this when the Hugging Face config stores the layer count under a different field.
Parameters:
huggingface_config (AutoConfig)
Return type:
int
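The override hook can be sketched with two hypothetical classes: a base that reads the conventional num_hidden_layers field and a subclass for a config that names it differently (n_layer, the field GPT-2's Hugging Face config uses). Neither class is part of MAX; they only illustrate the static-method override the description mentions.

```python
from types import SimpleNamespace


class BaseArchConfig:
    """Hypothetical stand-in for a base arch config (not MAX's class)."""

    @staticmethod
    def get_num_layers(hf_config) -> int:
        # Default: the conventional Hugging Face field name.
        return hf_config.num_hidden_layers


class CustomArchConfig(BaseArchConfig):
    """Hypothetical architecture whose HF config names the field differently."""

    @staticmethod
    def get_num_layers(hf_config) -> int:
        # Override: read the layer count from a differently named field.
        return hf_config.n_layer


print(BaseArchConfig.get_num_layers(SimpleNamespace(num_hidden_layers=48)))  # 48
print(CustomArchConfig.get_num_layers(SimpleNamespace(n_layer=12)))  # 12
```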
head_dim
head_dim: int
hidden_size
hidden_size: int
initialize()
classmethod initialize(pipeline_config, model_config=None)
Initializes the config from a PipelineConfig.
Parameters:
- pipeline_config (PipelineConfig) – The pipeline configuration.
- model_config (MAXModelConfig | None) – The model configuration to read from. When None (the default), pipeline_config.model is used. Pass an explicit config (e.g. pipeline_config.draft_model) to initialize the arch config for a different model.
Return type:
initialize_from_config()
classmethod initialize_from_config(pipeline_config, huggingface_config, model_config=None)
Parameters:
- pipeline_config (PipelineConfig)
- huggingface_config (AutoConfig)
- model_config (MAXModelConfig | None)
Return type:
intermediate_size
intermediate_size: int
intermediate_size_mlp
intermediate_size_mlp: int
kv_params
kv_params: KVCacheParams
max_seq_len
max_seq_len: int
model_quantization_encoding
model_quantization_encoding: QuantizationEncoding | None
moe_layers
moe_layers: list[int]
no_rope_layers
no_rope_layers: list[int]
norm_dtype
norm_dtype: DType | None = None
num_attention_heads
num_attention_heads: int
num_experts_per_tok
num_experts_per_tok: int
num_hidden_layers
num_hidden_layers: int
num_key_value_heads
num_key_value_heads: int
num_local_experts
num_local_experts: int
quant_config
quant_config: QuantConfig | None = None
quantization_config
quantization_config: QuantizationConfig | None
return_hidden_states
return_hidden_states: ReturnHiddenStates = 'none'
return_logits
return_logits: ReturnLogits = 'last_token'
rms_norm_eps
rms_norm_eps: float = 1e-05
rope_scaling_params
rope_scaling_params: Llama3RopeScalingParams | None
rope_theta
rope_theta: float
tie_word_embeddings
tie_word_embeddings: bool = False
use_qk_norm
use_qk_norm: bool
vocab_size
vocab_size: int
Llama4Model
class max.pipelines.architectures.llama4.Llama4Model(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN, return_hidden_states=ReturnHiddenStates.NONE, max_batch_size=1)
Bases: LlamaModelBase
Llama4 text-only pipeline model.
Reuses LlamaModelBase for input preparation, execution, and KV-cache management, and overrides only graph construction to build the Llama4 nn.Module.
Parameters:
- pipeline_config (PipelineConfig) – The configuration for this pipeline.
- session (InferenceSession) – The container for the runtime for this model.
- devices (list[Device])
- kv_cache_config (KVCacheConfig)
- weights (Weights)
- adapter (WeightsAdapter | None)
- return_logits (ReturnLogits)
- return_hidden_states (ReturnHiddenStates)
- max_batch_size (int)
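The split described above, where the base class owns the shared pipeline steps and the subclass supplies only graph construction, is a template-method arrangement. The sketch below shows that shape with hypothetical, framework-free classes; the method names (`prepare_inputs`, `build_graph`, `execute`) are illustrative assumptions and do not correspond to LlamaModelBase's actual API.

```python
from abc import ABC, abstractmethod


class ModelBase(ABC):
    """Hypothetical base: owns input preparation and execution."""

    def prepare_inputs(self, tokens: list[int]) -> list[int]:
        # Shared preprocessing used by every subclass (here: drop negative ids).
        return [t for t in tokens if t >= 0]

    @abstractmethod
    def build_graph(self) -> str:
        """Subclasses describe the model graph they build."""

    def execute(self, tokens: list[int]) -> str:
        # Shared control flow: prepare inputs, then run the subclass's graph.
        inputs = self.prepare_inputs(tokens)
        return f"{self.build_graph()} on {len(inputs)} tokens"


class TextOnlyModel(ModelBase):
    """Hypothetical subclass that overrides only graph construction."""

    def build_graph(self) -> str:
        return "text-only graph"


print(TextOnlyModel().execute([1, 2, -1, 3]))  # text-only graph on 3 tokens
```

Keeping the hook abstract means the base class cannot be instantiated on its own, so every concrete model must supply its own graph while inheriting the rest unchanged.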
attention_bias
attention_bias: bool = False
Whether to use attention bias.
model_config_cls
model_config_cls
alias of Llama4Config
norm_method
norm_method: Literal['rms_norm'] | Literal['layer_norm'] = 'rms_norm'
Normalization layer.