Python module

model_config

MAX model config classes.

MAXModelConfig

class max.pipelines.lib.model_config.MAXModelConfig(*, config_file=None, section_name=None, use_subgraphs=True, data_parallel_degree=1, max_length=None, model_path='', served_model_name=None, weight_path=<factory>, quantization_encoding=None, allow_safetensors_weights_fp32_bf6_bidirectional_cast=False, huggingface_model_revision='main', huggingface_weight_revision='main', trust_remote_code=False, device_specs=<factory>, force_download=False, vision_config_overrides=<factory>, rope_type=None, kv_cache=<factory>)

Bases: MAXModelConfigBase

Parameters:

  • config_file (str | None)
  • section_name (str | None)
  • use_subgraphs (bool)
  • data_parallel_degree (int)
  • max_length (int | None)
  • model_path (str)
  • served_model_name (str | None)
  • weight_path (list[Path])
  • quantization_encoding (Literal['float32', 'bfloat16', 'q4_k', 'q4_0', 'q6_k', 'float8_e4m3fn', 'float4_e2m1fnx2', 'gptq'] | None)
  • allow_safetensors_weights_fp32_bf6_bidirectional_cast (bool)
  • huggingface_model_revision (str)
  • huggingface_weight_revision (str)
  • trust_remote_code (bool)
  • device_specs (list[DeviceSpec])
  • force_download (bool)
  • vision_config_overrides (dict[str, Any])
  • rope_type (Literal['none', 'normal', 'neox', 'longrope', 'yarn'] | None)
  • kv_cache (KVCacheConfig)

allow_safetensors_weights_fp32_bf6_bidirectional_cast

allow_safetensors_weights_fp32_bf6_bidirectional_cast: bool

create_kv_cache_config()

create_kv_cache_config(**kv_cache_kwargs)

Create and set the KV cache configuration with the given parameters.

This method creates a new KVCacheConfig from the provided keyword arguments and automatically sets the cache_dtype based on the model’s quantization encoding (or any explicit override in kv_cache_kwargs).

Parameters:

**kv_cache_kwargs – Keyword arguments to pass to KVCacheConfig constructor. Common options include:

  • cache_strategy: The KV cache strategy (continuous, paged, etc.)
  • kv_cache_page_size: Number of tokens per page for paged cache
  • enable_prefix_caching: Whether to enable prefix caching
  • device_memory_utilization: Fraction of device memory to use
  • cache_dtype: Override for the cache data type

Return type:

None

data_parallel_degree

data_parallel_degree: int

default_device_spec

property default_device_spec: DeviceSpec

Returns the default device spec for the model.

This is the first device spec in the list, used for device spec checks throughout config validation.

Returns:

The default device spec for the model.

device_specs

device_specs: list[DeviceSpec]

diffusers_config

property diffusers_config: dict[str, Any] | None

Retrieve the diffusers config for diffusion pipelines.

Note: For multiprocessing, __getstate__ clears _diffusers_config before pickling. Each worker process will reload the config fresh.

Returns:

The diffusers config dict if this is a diffusion pipeline, None otherwise. The dict will have a structure with “_class_name” and “components” keys, where each component includes “class_name” and “config_dict” fields.

force_download

force_download: bool

generation_config

property generation_config: GenerationConfig

Retrieve the Hugging Face GenerationConfig for this model.

This property lazily loads the GenerationConfig from the model repository and caches it to avoid repeated remote fetches.

Returns:

The GenerationConfig for the model, containing generation parameters like max_length, temperature, top_p, etc. If loading fails, returns a default GenerationConfig.
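The lazy-load-and-cache behavior described above can be sketched in isolation. This is a minimal illustration, not the library's implementation: `LazyGenerationConfig`, `fake_loader`, and `failing_loader` are hypothetical names, a plain dict stands in for the Hugging Face GenerationConfig, and an empty dict stands in for the default config returned on failure.

```python
from typing import Callable, Optional


class LazyGenerationConfig:
    """Hypothetical holder that loads a config once and caches it."""

    def __init__(self, loader: Callable[[], dict]) -> None:
        self._loader = loader              # stand-in for the remote fetch
        self._cached: Optional[dict] = None

    @property
    def generation_config(self) -> dict:
        if self._cached is None:
            try:
                self._cached = self._loader()
            except Exception:
                # Assumption: fall back to an empty default when loading fails.
                self._cached = {}
        return self._cached


calls = []


def fake_loader() -> dict:
    calls.append(1)                        # record each (simulated) remote fetch
    return {"temperature": 0.7, "top_p": 0.9}


def failing_loader() -> dict:
    raise OSError("repository unreachable")


holder = LazyGenerationConfig(fake_loader)
first = holder.generation_config
second = holder.generation_config          # served from the cache
fallback = LazyGenerationConfig(failing_loader).generation_config
```

The second access returns the cached object without calling the loader again, and a failing loader yields the default rather than raising.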

graph_quantization_encoding

property graph_quantization_encoding: QuantizationEncoding | None

Converts the CLI encoding to a MAX Graph quantization encoding.

Returns:

The graph quantization encoding corresponding to the CLI encoding.

Raises:

ValueError – If no CLI encoding was specified.

huggingface_config

property huggingface_config: AutoConfig | None

Returns the Hugging Face model config (loaded on first access).

huggingface_model_repo

property huggingface_model_repo: HuggingFaceRepo

Returns the Hugging Face repo handle for the model.

huggingface_model_revision

huggingface_model_revision: str

huggingface_weight_repo

property huggingface_weight_repo: HuggingFaceRepo

Returns the Hugging Face repo handle for weight files.

huggingface_weight_repo_id

property huggingface_weight_repo_id: str

Returns the Hugging Face repo ID used for weight files.

huggingface_weight_revision

huggingface_weight_revision: str

kv_cache

kv_cache: KVCacheConfig

max_length

max_length: int | None

model_config

model_config: ClassVar[ConfigDict] = {'arbitrary_types_allowed': True, 'extra': 'forbid', 'strict': False}

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

model_name

property model_name: str

Returns the served model name or model path.
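As a rough sketch of the fallback this property describes, the helper below assumes that an explicit served_model_name takes precedence and that model_path is used otherwise. The function name `resolve_model_name` is hypothetical, and the precedence is an assumption based on the one-line description.

```python
from typing import Optional


def resolve_model_name(model_path: str, served_model_name: Optional[str] = None) -> str:
    """Hypothetical helper: prefer the served name, else fall back to the path."""
    # Assumed precedence: an explicit served name wins over the model path.
    return served_model_name if served_model_name is not None else model_path


name_from_path = resolve_model_name("org/model")
name_from_alias = resolve_model_name("org/model", served_model_name="my-alias")
```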

model_path

model_path: str

model_post_init()

model_post_init(context, /)

This function is meant to behave like a BaseModel method to initialise private attributes.

It takes context as an argument since that’s what pydantic-core passes when calling it.

Parameters:

  • self (BaseModel) – The BaseModel instance.
  • context (Any) – The context.

Return type:

None

quantization_encoding

quantization_encoding: SupportedEncoding | None

resolve()

resolve()

Validates and resolves the config.

This method is called after the model config is initialized to ensure that all config fields are in a valid state. It also sets or updates fields that cannot be determined in the default factory.

In order:

  1. Validate that the device_specs provided are available
  2. Parse the weight path(s) and initialize the _weights_repo_id

Return type:

None

rope_type

rope_type: RopeType | None

sampling_params_defaults

property sampling_params_defaults: SamplingParamsGenerationConfigDefaults

Returns sampling defaults derived from the generation config.

served_model_name

served_model_name: str | None

set_cache_dtype_given_quantization_encoding()

set_cache_dtype_given_quantization_encoding()

Determine the KV cache dtype based on quantization encoding configuration.

The dtype is determined in the following priority order:

  1. Explicit override from kv_cache.kv_cache_format (if set)
  2. Derived from the model’s quantization_encoding
  3. Falls back to float32 if no encoding is specified

Returns:

The DType to use for the KV cache. Typical values are:

  • DType.float32 for float32, q4_k, q4_0, q6_k encodings
  • DType.bfloat16 for bfloat16, float8_e4m3fn, float4_e2m1fnx2, gptq encodings

Return type:

DType
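The priority order and the encoding-to-dtype mapping listed for this method can be written out as a small standalone function. This is only a sketch: `pick_cache_dtype` is a hypothetical name, dtypes are represented as plain strings rather than DType values, and the override argument stands in for whatever explicit setting the KV cache config carries.

```python
from typing import Optional

# Mapping copied from the typical values documented for this method.
_ENCODING_TO_CACHE_DTYPE = {
    "float32": "float32",
    "q4_k": "float32",
    "q4_0": "float32",
    "q6_k": "float32",
    "bfloat16": "bfloat16",
    "float8_e4m3fn": "bfloat16",
    "float4_e2m1fnx2": "bfloat16",
    "gptq": "bfloat16",
}


def pick_cache_dtype(encoding: Optional[str], override: Optional[str] = None) -> str:
    """Hypothetical helper mirroring the documented priority order."""
    if override is not None:       # 1. an explicit override wins
        return override
    if encoding is not None:       # 2. derive from the quantization encoding
        return _ENCODING_TO_CACHE_DTYPE[encoding]
    return "float32"               # 3. fallback when no encoding is specified


dtype_for_gguf = pick_cache_dtype("q4_k")
dtype_for_gptq = pick_cache_dtype("gptq")
dtype_default = pick_cache_dtype(None)
dtype_overridden = pick_cache_dtype("float32", override="bfloat16")
```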

trust_remote_code

trust_remote_code: bool

use_subgraphs

use_subgraphs: bool

validate_and_resolve_quantization_encoding_weight_path()

validate_and_resolve_quantization_encoding_weight_path(default_encoding)

Verifies that the quantization encoding and weight path are consistent.

Parameters:

  • weight_path – The path to the weight file.
  • default_encoding (Literal['float32', 'bfloat16', 'q4_k', 'q4_0', 'q6_k', 'float8_e4m3fn', 'float4_e2m1fnx2', 'gptq']) – The default encoding to use if no encoding is provided.

Return type:

None

validate_and_resolve_rope_type()

validate_and_resolve_rope_type(arch_rope_type)

Resolves rope_type from architecture default if not set.

Parameters:

arch_rope_type (Literal['none', 'normal', 'neox', 'longrope', 'yarn'])

Return type:

None
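A minimal sketch of the resolution rule, assuming an explicitly set rope_type is kept and the architecture default is used only when it is unset. The helper `resolve_rope_type` is hypothetical and returns the resolved value for illustration, whereas the documented method returns None and updates the config. The membership check is an added assumption.

```python
from typing import Optional

# Allowed values, taken from the rope_type parameter's Literal type.
ROPE_TYPES = ("none", "normal", "neox", "longrope", "yarn")


def resolve_rope_type(rope_type: Optional[str], arch_rope_type: str) -> str:
    """Hypothetical helper: keep an explicit rope_type, else use the architecture default."""
    resolved = arch_rope_type if rope_type is None else rope_type
    if resolved not in ROPE_TYPES:
        raise ValueError(f"unknown rope type: {resolved!r}")
    return resolved


from_arch = resolve_rope_type(None, "neox")
explicit = resolve_rope_type("yarn", "neox")
```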

validate_and_resolve_with_resolved_quantization_encoding()

validate_and_resolve_with_resolved_quantization_encoding(supported_encodings, default_weights_format)

Validates model path and weight path against resolved quantization encoding.

Also resolves the KV cache strategy and finalizes the encoding config.

Parameters:

  • supported_encodings (dict[Literal['float32', 'bfloat16', 'q4_k', 'q4_0', 'q6_k', 'float8_e4m3fn', 'float4_e2m1fnx2', 'gptq'], list[~typing.Literal['model_default', 'paged']]]) – A dictionary of supported encodings and their corresponding KV cache strategies.
  • default_weights_format (WeightsFormat) – The default weights format to use if no weights format is provided.

Return type:

None
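To make the shape of supported_encodings concrete, the sketch below builds a small dictionary of the documented form and looks up the strategies for one encoding. The dictionary contents are illustrative only, and `strategies_for` is a hypothetical helper. What the real method does with an unsupported encoding is not stated here, so the ValueError is an assumption.

```python
from typing import Dict, List

# Illustrative values: encoding name -> KV cache strategies it supports.
supported_encodings: Dict[str, List[str]] = {
    "bfloat16": ["paged"],
    "float32": ["paged", "model_default"],
}


def strategies_for(encoding: str, supported: Dict[str, List[str]]) -> List[str]:
    """Hypothetical check: return the KV cache strategies for a supported encoding."""
    if encoding not in supported:
        # Assumption: an unsupported encoding is reported as a ValueError.
        raise ValueError(f"encoding {encoding!r} is not one of {sorted(supported)}")
    return supported[encoding]


bf16_strategies = strategies_for("bfloat16", supported_encodings)
```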

validate_lora_compatibility()

validate_lora_compatibility()

Validates that LoRA configuration is compatible with model settings.

Raises:

ValueError – If LoRA is enabled but incompatible with current model configuration.

Return type:

None

validate_max_length()

classmethod validate_max_length(v)

Validate that max_length is non-negative if provided.

Parameters:

v (int | None)

Return type:

int | None
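The check described here amounts to a few lines. The sketch below is a plain function rather than the actual classmethod validator, and the error type and message are assumptions.

```python
from typing import Optional


def validate_max_length(v: Optional[int]) -> Optional[int]:
    """Sketch of the documented rule: None passes through, negatives are rejected."""
    if v is not None and v < 0:
        # Assumption: the real validator signals the problem with a ValueError.
        raise ValueError("max_length must be non-negative")
    return v


unset = validate_max_length(None)
zero = validate_max_length(0)
typical = validate_max_length(4096)
```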

validate_multi_gpu_supported()

validate_multi_gpu_supported(multi_gpu_supported)

Validates that the model architecture supports multi-GPU inference.

Parameters:

multi_gpu_supported (bool) – Whether the model architecture supports multi-GPU inference.

Return type:

None

vision_config_overrides

vision_config_overrides: dict[str, Any]

weight_path

weight_path: list[Path]

weights_size()

weights_size()

Calculates the total size in bytes of all weight files in weight_path.

Attempts to find the weights locally first to avoid network calls, checking in the following order:

  1. If repo_type is RepoType.local, it checks if the path in weight_path exists directly as a local file path.
  2. Otherwise, if repo_type is RepoType.online, it first checks the local Hugging Face cache using huggingface_hub.try_to_load_from_cache(). If not found in the cache, it falls back to querying the Hugging Face Hub API via HuggingFaceRepo.size_of().

Returns:

The total size of all weight files in bytes.

Raises:

  • FileNotFoundError – If repo_type is RepoType.local and a file specified in weight_path is not found within the local repo directory.
  • ValueError – If HuggingFaceRepo.size_of() fails to retrieve the file size from the Hugging Face Hub API (e.g., file metadata not available or API error).
  • RuntimeError – If the determined repo_type is unexpected.

Return type:

int
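The local-first lookup can be sketched with the standard library alone. This is a simplified illustration under stated assumptions: `total_weights_size`, `local_root`, and `remote_size_of` are hypothetical names, the per-file fallback replaces the repo_type branching described above, and the callable stands in for a Hub API query. No Hugging Face cache lookup is modeled.

```python
import tempfile
from pathlib import Path
from typing import Callable, Iterable, Optional


def total_weights_size(
    weight_paths: Iterable[Path],
    local_root: Optional[Path] = None,
    remote_size_of: Optional[Callable[[Path], int]] = None,
) -> int:
    """Hypothetical helper: sum file sizes, preferring local files over remote lookups."""
    total = 0
    for rel in weight_paths:
        candidate = (local_root / rel) if local_root is not None else rel
        if candidate.is_file():
            total += candidate.stat().st_size   # local hit, no network call
        elif remote_size_of is not None:
            total += remote_size_of(rel)        # stand-in for a Hub API query
        else:
            raise FileNotFoundError(f"weight file not found: {candidate}")
    return total


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "a.bin").write_bytes(b"\x00" * 10)
    (root / "b.bin").write_bytes(b"\x00" * 32)
    # Both files exist locally, so their sizes are read from disk.
    local_total = total_weights_size([Path("a.bin"), Path("b.bin")], local_root=root)
    # The missing file falls back to the (simulated) remote size lookup.
    mixed_total = total_weights_size(
        [Path("a.bin"), Path("missing.bin")],
        local_root=root,
        remote_size_of=lambda p: 100,
    )
```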

MAXModelConfigBase

class max.pipelines.lib.model_config.MAXModelConfigBase(*, config_file=None, section_name=None)

Bases: ConfigFileModel

Abstract base class for all (required) MAX model configs.

This base class configures a model for use in a pipeline. It is also a handy way to sidestep the need to pass in optional fields when subclassing MAXModelConfig.

Parameters:

  • config_file (str | None)
  • section_name (str | None)

model_config

model_config: ClassVar[ConfigDict] = {'arbitrary_types_allowed': True, 'extra': 'forbid', 'strict': False}

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.
