Python module

max.pipelines.architectures.olmo

OLMo transformer architecture for text generation.

OlmoConfig

class max.pipelines.architectures.olmo.OlmoConfig(*, hidden_size, num_attention_heads, num_key_value_heads, num_hidden_layers, rope_theta, rope_scaling_params, max_seq_len, intermediate_size, interleaved_rope_weights, vocab_size, dtype, model_quantization_encoding, quantization_config, kv_params, return_logits=ReturnLogits.LAST_TOKEN, norm_method='rms_norm', norm_dtype=None, attention_bias=False, rms_norm_eps=None, tie_word_embeddings=False, stacked_mlp=False, stacked_qkv=False, attention_multiplier, embedding_multiplier, residual_multiplier, devices, clip_qkv, quant_config=None, lora_config=None, longrope_scaling_params=None, logits_scaling=1.0, return_hidden_states=ReturnHiddenStates.NONE, use_subgraphs=True, data_parallel_degree=1)

Bases: Llama3Config

Model configuration for OLMo graph construction and execution.

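Since `OlmoConfig` derives from `Llama3Config`, most fields are inherited and only OLMo-specific defaults differ. A minimal sketch of that subclassing pattern, using stand-in dataclasses rather than the real MAX classes (field names mirror a few from the signature above; the base class here is purely illustrative):

```python
from dataclasses import dataclass


@dataclass
class Llama3ConfigSketch:
    # Stand-in for Llama3Config with a handful of representative fields.
    hidden_size: int
    num_attention_heads: int
    num_hidden_layers: int
    vocab_size: int
    norm_method: str = "rms_norm"
    attention_bias: bool = False


@dataclass
class OlmoConfigSketch(Llama3ConfigSketch):
    # OLMo overrides the normalization default (see OlmoModel.norm_method).
    norm_method: str = "layer_norm"


cfg = OlmoConfigSketch(
    hidden_size=2048,
    num_attention_heads=16,
    num_hidden_layers=16,
    vocab_size=50304,
)
print(cfg.norm_method)  # layer_norm
```

All other inherited defaults (e.g. `attention_bias=False`) carry over unchanged, so the subclass only has to state what is different.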
finalize()

finalize(huggingface_config, state_dict, return_logits, return_hidden_states=ReturnHiddenStates.NONE, norm_method='rms_norm', attention_bias=False)

Finalize configuration parameters that cannot be determined from the pipeline config alone, such as values that depend on the Hugging Face config or the checkpoint state dict.

Return type:

None
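The idea behind this kind of `finalize()` hook can be sketched as follows. This is not the real MAX implementation; the field names and lookup keys are hypothetical, chosen only to show how a config might be completed from a loaded Hugging Face config and state dict:

```python
def finalize_sketch(cfg: dict, huggingface_config: dict, state_dict: dict) -> None:
    """Fill in config fields that are only known after checkpoint load.

    Hypothetical sketch: `cfg` stands in for the pipeline config object,
    and the keys below are illustrative, not the real MAX field names.
    """
    # The norm epsilon lives in the Hugging Face config, not the pipeline config.
    cfg["rms_norm_eps"] = huggingface_config.get("rms_norm_eps", 1e-5)
    # Tie word embeddings when the checkpoint ships no separate lm_head weight.
    cfg["tie_word_embeddings"] = "lm_head.weight" not in state_dict


cfg = {}
finalize_sketch(
    cfg,
    huggingface_config={"rms_norm_eps": 1e-6},
    state_dict={"model.embed_tokens.weight": None},
)
print(cfg)
```

The method mutates the config in place and returns `None`, matching the documented return type.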

OlmoModel

class max.pipelines.architectures.olmo.OlmoModel(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN)

Bases: LlamaModelBase

OLMo pipeline model implementation.

norm_method

norm_method: Literal['rms_norm'] | Literal['layer_norm'] = 'layer_norm'

Normalization method applied in the model's transformer layers; OLMo defaults to 'layer_norm'.
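The difference between the two `norm_method` options is small but meaningful: RMSNorm scales by the root mean square only, while LayerNorm also subtracts the mean. A pure-Python illustration (the real model uses optimized kernels, and omits the learned scale/shift parameters shown in textbooks):

```python
import math


def rms_norm(x: list[float], eps: float = 1e-6) -> list[float]:
    # Scale by the root mean square; no mean subtraction.
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]


def layer_norm(x: list[float], eps: float = 1e-6) -> list[float]:
    # Subtract the mean, then scale by the standard deviation.
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


x = [1.0, 2.0, 3.0, 4.0]
print(rms_norm(x))
print(layer_norm(x))
```

After `layer_norm` the output is zero-mean, while `rms_norm` preserves the sign and relative offsets of the input; both leave the output with unit scale.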