Python module
max.pipelines.architectures.gpt_oss
GPT-OSS mixture-of-experts architecture for text generation.
GptOssConfig
class max.pipelines.architectures.gpt_oss.GptOssConfig(*, vocab_size, hidden_size, intermediate_size, num_hidden_layers, num_attention_heads, num_key_value_heads, head_dim, hidden_activation, max_position_embeddings, rms_norm_eps, rope_theta, attention_bias, sliding_window, num_local_experts, num_experts_per_tok, router_aux_loss_coef, layer_types, attention_dropout, rope_scaling, query_pre_attn_scalar, final_logit_softcapping, attn_logit_softcapping, swiglu_limit, dtype, devices, interleaved_rope_weights, kv_params, quant_config=None, tie_word_embeddings=False, return_logits=ReturnLogits.LAST_TOKEN)
Bases: ArchConfigWithKVCache
Configuration for GPT OSS models.
Contains parameters specific to the GPT OSS architecture, typically extracted from a HuggingFace configuration object’s text config.
Parameters:
- vocab_size (int)
- hidden_size (int)
- intermediate_size (int)
- num_hidden_layers (int)
- num_attention_heads (int)
- num_key_value_heads (int)
- head_dim (int)
- hidden_activation (str)
- max_position_embeddings (int)
- rms_norm_eps (float)
- rope_theta (float)
- attention_bias (bool)
- sliding_window (int)
- num_local_experts (int)
- num_experts_per_tok (int)
- router_aux_loss_coef (float)
- layer_types (list[str])
- attention_dropout (float)
- rope_scaling (YarnScalingParams)
- query_pre_attn_scalar (float | None)
- final_logit_softcapping (float | None)
- attn_logit_softcapping (float | None)
- swiglu_limit (float)
- dtype (DType)
- devices (list[DeviceRef])
- interleaved_rope_weights (bool)
- kv_params (KVCacheParams)
- quant_config (QuantConfig | None)
- tie_word_embeddings (bool)
- return_logits (ReturnLogits)
attention_bias
attention_bias: bool
Whether to use a bias in the query, key, value, and output projection layers during self-attention.
attention_dropout
attention_dropout: float
Dropout probability for attention weights.
attn_logit_softcapping
attn_logit_softcapping: float | None
Softcapping value for attention logits.
calculate_max_seq_len()
static calculate_max_seq_len(pipeline_config, huggingface_config)
Calculates the maximum sequence length for the model.
Uses the max_length from the max.pipelines.config.PipelineConfig if provided,
otherwise falls back to the max_position_embeddings from the HuggingFace
configuration’s text config.
Parameters:
- pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
- huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).
Returns:
The calculated maximum sequence length.
Return type:
int
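The documented fallback logic (prefer the pipeline's max_length, else the HuggingFace config's max_position_embeddings) can be sketched as follows; the function name and parameters are illustrative, not the MAX API:

```python
def resolve_max_seq_len(pipeline_max_length, hf_max_position_embeddings):
    """Sketch of the documented fallback: use the pipeline config's
    max_length when set, otherwise the HuggingFace config's
    max_position_embeddings."""
    if pipeline_max_length is not None:
        return pipeline_max_length
    return hf_max_position_embeddings
```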
construct_kv_params()
static construct_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)
Constructs the KV cache parameters from configuration objects.
Parameters:
- huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).
- pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
- devices (list[DeviceRef]) – The list of devices the model will run on.
- kv_cache_config (KVCacheConfig) – The MAX Engine KV cache configuration settings (max.pipelines.max_config.KVCacheConfig).
- cache_dtype (DType) – The desired data type for the KV cache (max.dtype.DType).
Returns:
The configured max.pipelines.kv_cache.KVCacheParams object.
Return type:
KVCacheParams
devices
devices: list[DeviceRef]
Devices to run the model with.
dtype
dtype: DType
DType of the model weights and input.
final_logit_softcapping
final_logit_softcapping: float | None
Softcapping value for final logits.
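Logit softcapping smoothly bounds a logit to the open interval (-cap, cap) via cap * tanh(x / cap), which is approximately the identity for small values. A minimal sketch:

```python
import math

def softcap(x: float, cap: float) -> float:
    """Softcapping: squashes a logit into (-cap, cap) while leaving
    small values nearly unchanged."""
    return cap * math.tanh(x / cap)
```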
finalize()
finalize(huggingface_config, state_dict, return_logits)
Define parameters that can’t be determined just from the pipeline config.
Parameters:
- huggingface_config (AutoConfig) – The HuggingFace model configuration object.
- state_dict (dict[str, WeightData]) – The model’s state dictionary containing weights.
- return_logits (ReturnLogits) – Whether to return the last token, all tokens, or a variable number of logits.
Return type:
None
get_kv_params()
get_kv_params()
KV cache parameters to use when running the model.
Return type:
KVCacheParams
get_max_seq_len()
get_max_seq_len()
Returns the default maximum sequence length for the model.
Subclasses should determine whether this value can be overridden by
setting the --max-length (pipeline_config.model.max_length) flag.
Return type:
int
get_num_layers()
static get_num_layers(huggingface_config)
Retrieves the number of hidden layers from the HuggingFace configuration.
Parameters:
huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).
Returns:
The number of hidden layers specified in the configuration.
Return type:
int
head_dim
head_dim: int
The attention head dimension.
hidden_activation
hidden_activation: str
The non-linear activation function (function or string) in the decoder. Will default to “gelu_tanh” if not specified. “gelu_tanh” uses an approximation of the “gelu” activation function.
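The "gelu_tanh" approximation mentioned above is the standard tanh-based approximation of GELU; a minimal sketch (0.044715 is the usual approximation coefficient):

```python
import math

def gelu_tanh(x: float) -> float:
    """tanh approximation of GELU:
    0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))"""
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))
```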
hidden_size
hidden_size: int
Dimension of the hidden representations.
initialize()
classmethod initialize(pipeline_config, model_config=None)
Initializes a GptOssConfig instance from pipeline configuration.
This method creates a config instance with all fields that can be determined from the pipeline configuration, without needing the state_dict. Fields that depend on the state_dict (like tie_word_embeddings) should be set via the finalize() method.
Parameters:
- pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
- model_config (MAXModelConfig | None)
Returns:
An initialized GptOssConfig instance.
Return type:
GptOssConfig
interleaved_rope_weights
interleaved_rope_weights: bool
True if the rope weights are in interleaved complex format.
intermediate_size
intermediate_size: int
Dimension of the MLP representations.
kv_params
kv_params: KVCacheParams
KV cache parameters.
layer_types
layer_types: list[str]
Type of attention for each layer (‘full_attention’ or ‘sliding_attention’).
max_position_embeddings
max_position_embeddings: int
The maximum sequence length that this model might ever be used with.
num_attention_heads
num_attention_heads: int
Number of attention heads for each attention layer in the Transformer decoder.
num_experts_per_tok
num_experts_per_tok: int
Number of experts selected per token in MoE layers.
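Top-k expert selection of this kind is typically implemented as a softmax over the router's logits followed by picking the highest-scoring experts. A minimal sketch of that pattern (illustrative, not this module's router):

```python
import math

def route_token(router_logits: list[float], num_experts_per_tok: int):
    """Pick the top-k experts for one token and return (expert_index,
    normalized_weight) pairs, highest weight first."""
    topk = sorted(range(len(router_logits)),
                  key=lambda i: router_logits[i], reverse=True)[:num_experts_per_tok]
    m = max(router_logits[i] for i in topk)  # subtract max for numerical stability
    exps = {i: math.exp(router_logits[i] - m) for i in topk}
    z = sum(exps.values())
    return [(i, exps[i] / z) for i in topk]
```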
num_hidden_layers
num_hidden_layers: int
Number of hidden layers in the Transformer decoder.
num_key_value_heads
num_key_value_heads: int
Number of key_value heads that should be used to implement Grouped Query Attention.
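In grouped-query attention, consecutive groups of query heads share one key/value head; the mapping can be sketched as:

```python
def kv_head_for_query_head(q_head: int, num_attention_heads: int,
                           num_key_value_heads: int) -> int:
    """Which KV head a given query head reads from under grouped-query
    attention. Illustrative sketch, not the MAX implementation."""
    group_size = num_attention_heads // num_key_value_heads
    return q_head // group_size
```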
num_local_experts
num_local_experts: int
Number of experts in each MoE layer.
quant_config
quant_config: QuantConfig | None = None
Float8/Float4 quantization configuration, if applicable.
query_pre_attn_scalar
query_pre_attn_scalar: float | None
Scalar applied to queries before attention computation.
return_logits
return_logits: ReturnLogits = 'last_token'
Whether to return the last token, all logits, or a variable number of logits.
rms_norm_eps
rms_norm_eps: float
The epsilon used by the rms normalization layers.
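The normalization this epsilon guards divides by the vector's root mean square, with eps preventing division by zero; a minimal sketch:

```python
import math

def rms_norm(x: list[float], weight: list[float], eps: float = 1e-6) -> list[float]:
    """RMS normalization: scale each element by 1/rms(x), then by a learned
    per-channel weight. eps keeps the denominator nonzero."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * w for v, w in zip(x, weight)]
```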
rope_scaling
rope_scaling: YarnScalingParams
Scaling configuration for the RoPE embeddings used in global attention.
rope_theta
rope_theta: float
The base period of the RoPE embeddings.
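RoPE derives a rotation frequency for each pair of head dimensions from this base period, following the standard formulation theta^(-2i/d); a sketch:

```python
def rope_inv_freqs(head_dim: int, rope_theta: float) -> list[float]:
    """Per-pair RoPE rotation frequencies: theta^(-2i/d) for each of the
    head_dim // 2 dimension pairs. Illustrative sketch."""
    return [rope_theta ** (-2.0 * i / head_dim) for i in range(head_dim // 2)]
```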
router_aux_loss_coef
router_aux_loss_coef: float
Coefficient for the auxiliary load balancing loss in MoE layers.
sliding_window
sliding_window: int
In the GPT OSS language model, specific layers use sliding window attention. This is the size of the sliding window.
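The resulting mask rule can be sketched as follows, using the layer_types naming from this config; this is an illustration of the masking semantics, not the MAX attention kernel:

```python
def can_attend(q_pos: int, k_pos: int, layer_type: str, sliding_window: int) -> bool:
    """Causal mask: 'sliding_attention' layers only see the previous
    `sliding_window` positions; 'full_attention' layers see all prior tokens."""
    if k_pos > q_pos:
        return False  # causal: never attend to future tokens
    if layer_type == "sliding_attention":
        return q_pos - k_pos < sliding_window
    return True
```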
swiglu_limit
swiglu_limit: float
Clamping limit for SwiGLU activation in MoE layers.
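One plausible form of such clamping bounds both MLP branches before the SiLU-gated product; the exact scheme below is an assumption for illustration, not taken from this module:

```python
import math

def clamped_swiglu(gate: float, up: float, limit: float) -> float:
    """Sketch of SwiGLU with a clamping limit: inputs are clamped to
    [-limit, limit] (an assumed scheme) before SiLU(gate) * up."""
    gate = max(-limit, min(gate, limit))
    up = max(-limit, min(up, limit))
    silu = gate / (1.0 + math.exp(-gate))  # SiLU(gate) = gate * sigmoid(gate)
    return silu * up
```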
tie_word_embeddings
tie_word_embeddings: bool = False
Whether to tie weight embeddings. When true, the output linear layer uses the same weight as the embedding layer.
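With tied embeddings, the output projection is just a matrix product against the embedding table, so logits are dot products of the hidden state with each token's embedding row; a minimal sketch:

```python
def output_logits(hidden: list[float], embedding_matrix: list[list[float]]) -> list[float]:
    """Tied output projection: one logit per vocabulary entry, computed
    as hidden · embedding_row. Illustrative sketch."""
    return [sum(h * w for h, w in zip(hidden, row)) for row in embedding_matrix]
```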
vocab_size
vocab_size: int
Vocabulary size of the GPT OSS model.
GptOssModel
class max.pipelines.architectures.gpt_oss.GptOssModel(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN)
Bases: AlwaysSignalBuffersMixin, PipelineModelWithKVCache[TextContext]
A GPT OSS pipeline model for text generation.
This class integrates the GPT OSS architecture with the MAX Engine pipeline infrastructure, handling model loading, KV cache management, and input preparation for inference.
Parameters:
- pipeline_config (PipelineConfig) – The configuration settings for the entire pipeline.
- session (InferenceSession) – The MAX Engine inference session managing the runtime.
- devices (list[Device]) – A list of MAX Engine devices (max.driver.Device) to run the model on.
- kv_cache_config (KVCacheConfig) – Configuration settings for the Key-Value cache (max.pipelines.max_config.KVCacheConfig).
- weights (Weights) – The model weights (max.graph.weights.Weights).
- adapter (WeightsAdapter | None) – An optional adapter to modify weights before loading (max.graph.weights.WeightsAdapter).
- return_logits (ReturnLogits) – The number of top logits to return from the model execution.
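The prepare/execute cycle exposed by the methods below can be illustrated with a toy stand-in; ToyModel is hypothetical (the real model returns logits rather than a fixed token) and only mirrors the documented method names:

```python
class ToyModel:
    """Hypothetical stand-in that mimics the pipeline-model call sequence:
    prepare_initial_token_inputs -> execute -> prepare_next_token_inputs."""

    def prepare_initial_token_inputs(self, replica_batches, kv_cache_inputs=None,
                                     return_n_logits=1):
        # One replica, one context: take its prompt tokens.
        return {"tokens": list(replica_batches[0][0])}

    def execute(self, model_inputs):
        # Pretend decoding always selects token 0.
        return {"next_token": 0}

    def prepare_next_token_inputs(self, next_tokens, prev_model_inputs):
        return {"tokens": prev_model_inputs["tokens"] + [next_tokens]}

model = ToyModel()
inputs = model.prepare_initial_token_inputs([[[5, 6, 7]]])
for _ in range(3):  # three decode steps
    out = model.execute(inputs)
    inputs = model.prepare_next_token_inputs(out["next_token"], inputs)
print(inputs["tokens"])  # [5, 6, 7, 0, 0, 0]
```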
calculate_max_seq_len()
static calculate_max_seq_len(pipeline_config, huggingface_config)
Calculates the maximum sequence length for the GPT OSS model.
Uses the max_length from the max.pipelines.config.PipelineConfig
if provided, otherwise falls back to the max_position_embeddings from
the HuggingFace configuration’s text config.
Parameters:
- pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
- huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).
Returns:
The calculated maximum sequence length.
Return type:
int
estimate_activation_memory()
classmethod estimate_activation_memory(pipeline_config, huggingface_config)
Estimates the activation memory required for model execution.
This accounts for temporary memory buffers used during model execution, such as intermediate activations and working buffers.
The default implementation returns 0 for backward compatibility. Models with significant activation memory requirements should override this method to provide accurate estimates.
Parameters:
- pipeline_config (PipelineConfig) – Pipeline configuration.
- huggingface_config (AutoConfig) – Hugging Face model configuration.
Returns:
Estimated activation memory in bytes.
Return type:
int
execute()
execute(model_inputs)
Executes the GPT OSS model with the prepared inputs.
Parameters:
model_inputs (ModelInputs) – The prepared inputs for the model execution, typically including token IDs, attention masks/offsets, and KV cache inputs.
Returns:
An object containing the output logits from the model execution.
Return type:
ModelOutputs
get_kv_params()
classmethod get_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)
Gets the parameters required to configure the KV cache for GPT OSS.
Delegates to the GptOssConfig.construct_kv_params static method.
Parameters:
- huggingface_config (AutoConfig) – The HuggingFace model configuration object (transformers.AutoConfig).
- pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
- devices (list[DeviceRef]) – The list of devices the model will run on.
- kv_cache_config (KVCacheConfig) – The MAX Engine KV cache configuration settings (max.pipelines.max_config.KVCacheConfig).
- cache_dtype (DType) – The desired data type for the KV cache (max.dtype.DType).
Returns:
The configured max.pipelines.kv_cache.KVCacheParams object.
Return type:
KVCacheParams
load_model()
load_model(session)
Loads the compiled GPT OSS model into the MAX Engine session.
Parameters:
session (InferenceSession) – The MAX Engine inference session.
Returns:
The loaded MAX Engine model object.
Return type:
Model
model
model: Model
The compiled and initialized MAX Engine model ready for inference.
prepare_initial_token_inputs()
prepare_initial_token_inputs(replica_batches, kv_cache_inputs=None, return_n_logits=1)
Prepares the initial inputs for the first execution pass of the GPT OSS model.
Parameters:
- replica_batches (Sequence[Sequence[TextContext]]) – Per-replica batches of TextContext objects representing the input prompts.
- kv_cache_inputs (KVCacheInputs[Buffer, Buffer] | None) – Optional inputs required by the KV cache manager.
- return_n_logits (int)
Returns:
The prepared ModelInputs object for the initial execution step.
Return type:
ModelInputs
prepare_next_token_inputs()
prepare_next_token_inputs(next_tokens, prev_model_inputs)
Prepares the inputs for subsequent execution steps in a multi-step generation.
Parameters:
- next_tokens (Buffer) – The tensor containing the token IDs generated in the previous step.
- prev_model_inputs (ModelInputs) – The ModelInputs used in the previous execution step.
Returns:
The prepared ModelInputs object for the next execution step.
Return type:
ModelInputs