Python module

max.pipelines.architectures.llama4

Llama 4 (text-only) transformer architecture for text generation.

Llama4Config​

class max.pipelines.architectures.llama4.Llama4Config(*, hidden_size, num_attention_heads, num_key_value_heads, num_hidden_layers, head_dim, rope_theta, rope_scaling_params, max_seq_len, intermediate_size, intermediate_size_mlp, vocab_size, dtype, model_quantization_encoding, quantization_config, kv_params, devices, num_local_experts, num_experts_per_tok, moe_layers, no_rope_layers, attention_chunk_size, use_qk_norm, attn_temperature_tuning, floor_scale, attn_scale, attention_bias, attention_multiplier, rms_norm_eps=1e-05, norm_dtype=None, tie_word_embeddings=False, quant_config=None, return_logits=ReturnLogits.LAST_TOKEN, return_hidden_states=ReturnHiddenStates.NONE, data_parallel_degree=1)

Bases: ArchConfigWithStoredKVParams, ArchConfigWithKVCache

Model configuration for Llama4 (text-only) graph construction.

attention_bias​

attention_bias: bool

attention_chunk_size​

attention_chunk_size: int

attention_multiplier​

attention_multiplier: float

attn_scale​

attn_scale: float

attn_temperature_tuning​

attn_temperature_tuning: bool

calculate_max_seq_len()​

classmethod calculate_max_seq_len(pipeline_config, huggingface_config, model_config=None)

Bounds max_length by the text config’s max_position_embeddings.

The multimodal Hugging Face Llama4Config exposes max_position_embeddings only under text_config, whereas the base implementation reads it from the top-level config. This override therefore routes the lookup through get_text_config() first.

Return type:

int
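
The routing described above can be illustrated with a minimal sketch. It is not the MAX implementation: FakeMultimodalConfig and bounded_max_seq_len are hypothetical names, and the only behavior taken from the description is that the limit is read from the nested text config and used as an upper bound on max_length.

```python
from types import SimpleNamespace


class FakeMultimodalConfig:
    """Hypothetical stand-in for a multimodal Hugging Face config."""

    def __init__(self, max_position_embeddings: int) -> None:
        # The limit lives only on the nested text config, not at the top level.
        self.text_config = SimpleNamespace(
            max_position_embeddings=max_position_embeddings
        )

    def get_text_config(self):
        return self.text_config


def bounded_max_seq_len(huggingface_config, max_length: int) -> int:
    """Clamp max_length to the text config's max_position_embeddings."""
    # Reading huggingface_config.max_position_embeddings directly would fail
    # here, so the lookup goes through get_text_config() first.
    text_config = huggingface_config.get_text_config()
    return min(max_length, text_config.max_position_embeddings)


config = FakeMultimodalConfig(max_position_embeddings=8192)
clamped = bounded_max_seq_len(config, max_length=100_000)
unchanged = bounded_max_seq_len(config, max_length=4096)
```

In this sketch a requested length above the model limit is clamped to the limit, and a shorter request is returned unchanged.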

construct_kv_params()​

classmethod construct_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)

Default KV params for standard grouped attention.

Return type:

KVCacheParams

data_parallel_degree​

data_parallel_degree: int = 1

devices​

devices: list[DeviceRef]

dtype​

dtype: DType

finalize()​

finalize(huggingface_config, state_dict, return_logits, return_hidden_states=ReturnHiddenStates.NONE)

Sets parameters that depend on the loaded state dict.

Return type:

None

floor_scale​

floor_scale: float

get_head_dim()​

static get_head_dim(huggingface_config)

Returns the attention head size, taken from head_dim or computed as hidden_size // num_attention_heads.

Parameters:

huggingface_config (AutoConfig)

Return type:

int
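
A minimal sketch of this lookup, assuming head_dim takes precedence when the config defines it and the division is the fallback. The function name head_dim_from_config and the SimpleNamespace configs are illustrative stand-ins, not part of the MAX API.

```python
from types import SimpleNamespace


def head_dim_from_config(huggingface_config) -> int:
    """Return head_dim if present, else hidden_size // num_attention_heads."""
    # Assumption: an explicit head_dim wins over the derived value.
    head_dim = getattr(huggingface_config, "head_dim", None)
    if head_dim is not None:
        return head_dim
    return huggingface_config.hidden_size // huggingface_config.num_attention_heads


# Config with an explicit head_dim that differs from the derived value.
explicit = SimpleNamespace(head_dim=64, hidden_size=4096, num_attention_heads=32)
# Config without head_dim, so the size is derived: 4096 // 32.
derived = SimpleNamespace(hidden_size=4096, num_attention_heads=32)

explicit_dim = head_dim_from_config(explicit)
derived_dim = head_dim_from_config(derived)
```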

get_num_layers()​

static get_num_layers(huggingface_config)

Returns the number of layers in the decoder stack. Override this method when the Hugging Face config stores the count under a different field.

Parameters:

huggingface_config (AutoConfig)

Return type:

int
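
The override pattern mentioned above might look like the following sketch. BaseArchConfig, CustomArchConfig, and the n_layer field are hypothetical; the default field name num_hidden_layers is also an assumption, chosen to match the attribute of the same name on this class.

```python
from types import SimpleNamespace


class BaseArchConfig:
    """Hypothetical base class with a default layer-count lookup."""

    @staticmethod
    def get_num_layers(huggingface_config) -> int:
        # Assumed default field name.
        return huggingface_config.num_hidden_layers


class CustomArchConfig(BaseArchConfig):
    """Hypothetical subclass for a config that names the field n_layer."""

    @staticmethod
    def get_num_layers(huggingface_config) -> int:
        return huggingface_config.n_layer


default_layers = BaseArchConfig.get_num_layers(SimpleNamespace(num_hidden_layers=48))
custom_layers = CustomArchConfig.get_num_layers(SimpleNamespace(n_layer=12))
```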

head_dim​

head_dim: int

hidden_size​

hidden_size: int

initialize()​

classmethod initialize(pipeline_config, model_config=None)

Initialize the config from a PipelineConfig.

Parameters:

  • pipeline_config (PipelineConfig) – The pipeline configuration.
  • model_config (MAXModelConfig | None) – The model configuration to read from. When None (the default), pipeline_config.model is used. Pass an explicit config (e.g. pipeline_config.draft_model) to initialize the arch config for a different model.

Return type:

Self
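
The model_config fallback described in the parameter list can be sketched as below. This only illustrates the selection rule, using SimpleNamespace stand-ins for PipelineConfig and MAXModelConfig; resolve_model_config is a hypothetical helper, not a function in the MAX API.

```python
from types import SimpleNamespace


def resolve_model_config(pipeline_config, model_config=None):
    """Pick the model config an arch config would be initialized from."""
    # When no explicit config is passed, fall back to pipeline_config.model.
    return pipeline_config.model if model_config is None else model_config


# Stand-in pipeline config with a main model and a draft model.
pipeline = SimpleNamespace(
    model=SimpleNamespace(name="main"),
    draft_model=SimpleNamespace(name="draft"),
)

main_cfg = resolve_model_config(pipeline)
draft_cfg = resolve_model_config(pipeline, pipeline.draft_model)
```

Passing pipeline.draft_model explicitly selects the draft model's settings, while omitting the argument selects the main model.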

initialize_from_config()​

classmethod initialize_from_config(pipeline_config, huggingface_config, model_config=None)

Return type:

Self

intermediate_size​

intermediate_size: int

intermediate_size_mlp​

intermediate_size_mlp: int

kv_params​

kv_params: KVCacheParams

max_seq_len​

max_seq_len: int

model_quantization_encoding​

model_quantization_encoding: QuantizationEncoding | None

moe_layers​

moe_layers: list[int]

no_rope_layers​

no_rope_layers: list[int]

norm_dtype​

norm_dtype: DType | None = None

num_attention_heads​

num_attention_heads: int

num_experts_per_tok​

num_experts_per_tok: int

num_hidden_layers​

num_hidden_layers: int

num_key_value_heads​

num_key_value_heads: int

num_local_experts​

num_local_experts: int

quant_config​

quant_config: QuantConfig | None = None

quantization_config​

quantization_config: QuantizationConfig | None

return_hidden_states​

return_hidden_states: ReturnHiddenStates = 'none'

return_logits​

return_logits: ReturnLogits = 'last_token'

rms_norm_eps​

rms_norm_eps: float = 1e-05

rope_scaling_params​

rope_scaling_params: Llama3RopeScalingParams | None

rope_theta​

rope_theta: float

tie_word_embeddings​

tie_word_embeddings: bool = False

use_qk_norm​

use_qk_norm: bool

vocab_size​

vocab_size: int

Llama4Model​

class max.pipelines.architectures.llama4.Llama4Model(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN, return_hidden_states=ReturnHiddenStates.NONE, max_batch_size=1)

Bases: LlamaModelBase

Llama4 text-only pipeline model.

Reuses LlamaModelBase for input preparation, execution, and KV-cache management, and overrides only graph construction to build the Llama4 nn.Module.

attention_bias​

attention_bias: bool = False

Whether to use attention bias.

model_config_cls​

model_config_cls

alias of Llama4Config

norm_method​

norm_method: Literal['rms_norm'] | Literal['layer_norm'] = 'rms_norm'

Normalization layer type, either 'rms_norm' or 'layer_norm'.