Python module

max.pipelines.architectures.olmo2_modulev3

OLMo 2 transformer architecture for text generation.

Olmo2Config​

class max.pipelines.architectures.olmo2_modulev3.Olmo2Config(*, vocab_size, hidden_size, intermediate_size, num_hidden_layers, num_attention_heads, num_key_value_heads, head_dim, max_position_embeddings, rms_norm_eps, rope_theta, attention_bias, tie_word_embeddings, attention_multiplier, embedding_multiplier, residual_multiplier, dtype, devices, interleaved_rope_weights, return_logits, kv_params)

Bases: ArchConfigWithPermissiveMaxSeqLen, ArchConfigWithStoredKVParams, ArchConfigWithKVCache

Configuration for Olmo2 models.

Contains parameters specific to the Olmo2 architecture, typically extracted from a HuggingFace configuration object.

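Parameters:

  • vocab_size (int)
  • hidden_size (int)
  • intermediate_size (int)
  • num_hidden_layers (int)
  • num_attention_heads (int)
  • num_key_value_heads (int)
  • head_dim (int)
  • max_position_embeddings (int)
  • rms_norm_eps (float)
  • rope_theta (float)
  • attention_bias (bool)
  • tie_word_embeddings (bool)
  • attention_multiplier (float)
  • embedding_multiplier (float)
  • residual_multiplier (float)
  • dtype (DType)
  • devices (list[DeviceRef])
  • interleaved_rope_weights (bool)
  • return_logits (ReturnLogits)
  • kv_params (KVCacheParams)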

attention_bias​

attention_bias: bool

Whether to use a bias in the attention projection layers.

attention_multiplier​

attention_multiplier: float

Scalar applied to attention scores.

calculate_attention_multiplier()​

static calculate_attention_multiplier(huggingface_config)

Parameters:

huggingface_config (AutoConfig)

Return type:

float
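
This page does not say how the multiplier is derived. The sketch below assumes the common scaled dot-product convention of 1/sqrt(head_dim), with an optional `attention_multiplier` override on the config. `FakeHFConfig`, `sketch_attention_multiplier`, and the fallback order are hypothetical illustrations, not the library's implementation.

```python
import math
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeHFConfig:
    """Hypothetical stand-in for a HuggingFace config object."""

    hidden_size: int
    num_attention_heads: int
    head_dim: Optional[int] = None
    attention_multiplier: Optional[float] = None


def sketch_attention_multiplier(config: FakeHFConfig) -> float:
    # Assumption: an explicit override on the config takes precedence.
    if config.attention_multiplier is not None:
        return config.attention_multiplier
    # Assumption: otherwise use the conventional 1/sqrt(head_dim) scale,
    # deriving head_dim from hidden_size when it is not given.
    head_dim = config.head_dim or config.hidden_size // config.num_attention_heads
    return 1.0 / math.sqrt(head_dim)
```

With `hidden_size=4096` and `num_attention_heads=32`, the derived head dimension is 128, so the sketch returns 1/sqrt(128).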

construct_kv_params()​

classmethod construct_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)

Builds the KV cache parameters for Olmo2. Olmo2 does not support data parallelism, so this method uses the default grouped KV configuration (no EAGLE).

Parameters:

  • huggingface_config
  • pipeline_config
  • devices
  • kv_cache_config
  • cache_dtype

Return type:

KVCacheParams

devices​

devices: list[DeviceRef]

Devices to run the model with.

dtype​

dtype: DType

Data type of the model weights and inputs.

embedding_multiplier​

embedding_multiplier: float

Scalar applied to embeddings.

finalize()​

finalize(huggingface_config, state_dict, return_logits, return_hidden_states=ReturnHiddenStates.NONE, norm_method='rms_norm', attention_bias=False)

Defines parameters that cannot be determined from the pipeline config alone.

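Parameters:

  • huggingface_config
  • state_dict
  • return_logits
  • return_hidden_states
  • norm_method
  • attention_bias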

Return type:

None

get_num_layers()​

static get_num_layers(huggingface_config)

Returns the number of layers in the decoder stack. Override this method when the HuggingFace config stores the layer count under a different field.

Parameters:

huggingface_config (AutoConfig)

Return type:

int
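
The override pattern described here can be sketched with stand-in classes. `BaseConfigSketch`, `CustomConfigSketch`, and the `n_layer` field are hypothetical. The default field name `num_hidden_layers` is an assumption based on the attribute of the same name documented on this page.

```python
from types import SimpleNamespace


class BaseConfigSketch:
    @staticmethod
    def get_num_layers(huggingface_config) -> int:
        # Assumption: the default reads the conventional HF field name.
        return huggingface_config.num_hidden_layers


class CustomConfigSketch(BaseConfigSketch):
    @staticmethod
    def get_num_layers(huggingface_config) -> int:
        # Hypothetical architecture whose HF config uses `n_layer` instead.
        return huggingface_config.n_layer


# SimpleNamespace stands in for a HuggingFace config object.
print(BaseConfigSketch.get_num_layers(SimpleNamespace(num_hidden_layers=32)))
print(CustomConfigSketch.get_num_layers(SimpleNamespace(n_layer=12)))
```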

head_dim​

head_dim: int

Dimension of each attention head.

hidden_size​

hidden_size: int

Dimension of the hidden representations.

initialize()​

classmethod initialize(pipeline_config, model_config=None)

Initialize the config from a PipelineConfig.

Parameters:

  • pipeline_config (PipelineConfig) – The pipeline configuration.
  • model_config (MAXModelConfig | None) – The model configuration to read from. When None (the default), pipeline_config.model is used. Pass an explicit config (e.g. pipeline_config.draft_model) to initialize the arch config for a different model.

Return type:

Self
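
The `model_config` default described above can be illustrated without the real classes. In this sketch, `ModelConfigStub`, `PipelineConfigStub`, the `model_path` field, and `resolve_model_config` are hypothetical stand-ins. Only the fallback to `pipeline_config.model` and the `draft_model` example come from the text.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelConfigStub:
    model_path: str  # hypothetical field, for illustration only


@dataclass
class PipelineConfigStub:
    model: ModelConfigStub
    draft_model: Optional[ModelConfigStub] = None


def resolve_model_config(pipeline_config, model_config=None):
    # Mirrors the documented default: use pipeline_config.model when
    # no explicit model config is passed.
    return pipeline_config.model if model_config is None else model_config


pipeline = PipelineConfigStub(
    model=ModelConfigStub("main-model"),
    draft_model=ModelConfigStub("draft-model"),
)
main_cfg = resolve_model_config(pipeline)
draft_cfg = resolve_model_config(pipeline, pipeline.draft_model)
```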

interleaved_rope_weights​

interleaved_rope_weights: bool

Whether the RoPE weights are stored in interleaved complex format.
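
One common reading of "interleaved" is that the real and imaginary parts of each rotary pair sit next to each other, as opposed to a split layout that stores all real parts first. The sketch below converts between the two layouts under that assumption. The function names are hypothetical, and the page does not confirm which layout the library expects.

```python
def interleaved_to_split(values):
    # [r0, i0, r1, i1, ...] -> [r0, r1, ..., i0, i1, ...]
    if len(values) % 2:
        raise ValueError("expected an even number of elements")
    return values[0::2] + values[1::2]


def split_to_interleaved(values):
    # [r0, r1, ..., i0, i1, ...] -> [r0, i0, r1, i1, ...]
    if len(values) % 2:
        raise ValueError("expected an even number of elements")
    half = len(values) // 2
    out = []
    for real, imag in zip(values[:half], values[half:]):
        out.extend([real, imag])
    return out
```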

intermediate_size​

intermediate_size: int

Dimension of the MLP representations.

kv_params​

kv_params: KVCacheParams

KV cache parameters.
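
For a rough sense of what the KV cache costs in memory, the standard back-of-the-envelope estimate multiplies the number of layers, KV heads, head dimension, and cached tokens, then doubles the result for keys plus values. The helper below is a hypothetical sketch and does not reflect the fields of `KVCacheParams`. The example values are illustrative and not taken from a specific Olmo2 checkpoint.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, bytes_per_element):
    # Factor of 2: one tensor for keys and one for values per layer.
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_element


# Illustrative values only (16-bit elements, one sequence of 4096 tokens).
size = kv_cache_bytes(
    num_layers=32, num_kv_heads=8, head_dim=128, seq_len=4096, bytes_per_element=2
)
print(size / 1024**2, "MiB")  # 512.0 MiB
```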

max_position_embeddings​

max_position_embeddings: int

The maximum sequence length that this model might ever be used with.

num_attention_heads​

num_attention_heads: int

Number of attention heads for each attention layer.

num_hidden_layers​

num_hidden_layers: int

Number of hidden layers in the Transformer decoder.

num_key_value_heads​

num_key_value_heads: int

Number of key-value heads for grouped-query attention (GQA).
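
In grouped-query attention, several query heads share one key-value head. A minimal sketch of that mapping, assuming `num_attention_heads` is an exact multiple of `num_key_value_heads` and that consecutive query heads form a group (the helper name is hypothetical):

```python
def kv_head_for_query_head(query_head, num_attention_heads, num_key_value_heads):
    if num_attention_heads % num_key_value_heads:
        raise ValueError("num_attention_heads must be a multiple of num_key_value_heads")
    # Number of query heads that share a single key-value head.
    group_size = num_attention_heads // num_key_value_heads
    return query_head // group_size


# With 32 query heads and 8 KV heads, each KV head serves 4 query heads.
mapping = [kv_head_for_query_head(q, 32, 8) for q in range(32)]
```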

residual_multiplier​

residual_multiplier: float

Scalar applied to residual connections.

return_logits​

return_logits: ReturnLogits

Whether to return logits for the last token only, for all tokens, or for a variable number of tokens.
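
The three modes can be pictured as different slices over per-token logit rows. The enum below is a local, hypothetical stand-in. `LAST_TOKEN` appears in the `Olmo2Model` signature on this page, while the `ALL` and `VARIABLE` member names and the `num_rows` argument are assumptions made for illustration.

```python
from enum import Enum


class ReturnLogitsSketch(Enum):
    LAST_TOKEN = "last_token"  # name appears in a signature on this page
    ALL = "all"  # assumed member name
    VARIABLE = "variable"  # assumed member name


def select_logits(rows, mode, num_rows=1):
    if mode is ReturnLogitsSketch.LAST_TOKEN:
        return rows[-1:]
    if mode is ReturnLogitsSketch.ALL:
        return list(rows)
    # VARIABLE: keep the trailing `num_rows` positions.
    if num_rows <= 0:
        raise ValueError("num_rows must be positive")
    return rows[-num_rows:]
```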

rms_norm_eps​

rms_norm_eps: float

The epsilon used by the RMS normalization layers.
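
For reference, the textbook RMS normalization divides each element by the root mean square of the vector, with epsilon added under the square root to avoid division by zero. This is a plain-Python sketch of that formula, not the library's kernel:

```python
import math


def rms_norm(x, weight, eps):
    # Mean of squares over the feature dimension.
    mean_sq = sum(v * v for v in x) / len(x)
    # eps keeps the denominator positive even for an all-zero input.
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [v * scale * w for v, w in zip(x, weight)]
```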

rope_theta​

rope_theta: float

The base period of the RoPE embeddings.
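
In the standard RoPE formulation, the base determines one rotation frequency per pair of dimensions: `rope_theta ** (-2i / head_dim)` for pair index `i`. The helper below sketches that formula. Its name is hypothetical, and any frequency scaling the library may apply is not covered.

```python
def rope_inverse_frequencies(head_dim, rope_theta):
    if head_dim % 2:
        raise ValueError("head_dim must be even")
    # One frequency per pair of dimensions; larger i rotates more slowly.
    return [rope_theta ** (-(2 * i) / head_dim) for i in range(head_dim // 2)]


freqs = rope_inverse_frequencies(head_dim=4, rope_theta=10000.0)
```

A larger `rope_theta` lowers the higher-index frequencies, which stretches the range of positions the embedding can distinguish.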

tie_word_embeddings​

tie_word_embeddings: bool

Whether to tie the input and output word embeddings.
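
Tying means the output projection reuses the input embedding matrix instead of a separate set of weights. A minimal plain-Python sketch of that idea (the function and argument names are hypothetical):

```python
def output_logits(hidden, embedding, lm_head=None, tie_word_embeddings=True):
    # With tying, the output projection reuses the embedding matrix.
    projection = embedding if tie_word_embeddings else lm_head
    if projection is None:
        raise ValueError("lm_head is required when embeddings are not tied")
    # One logit per vocabulary row: dot product of hidden state and row.
    return [sum(h * w for h, w in zip(hidden, row)) for row in projection]


embedding = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # vocab_size=3, hidden_size=2
logits = output_logits([2.0, 3.0], embedding)
```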

vocab_size​

vocab_size: int

Vocabulary size of the Olmo2 model.

Olmo2Model​

class max.pipelines.architectures.olmo2_modulev3.Olmo2Model(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN, return_hidden_states=ReturnHiddenStates.NONE)

Bases: Llama3Model

An Olmo2 pipeline model for text generation.

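Parameters:

  • pipeline_config
  • session
  • devices
  • kv_cache_config
  • weights
  • adapter
  • return_logits
  • return_hidden_states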

load_model()​

load_model()

Return type:

Callable[[…], Any]

model_config_cls​

model_config_cls

alias of Olmo2Config