Python module
registry
Model registry, for tracking various model variants.
PipelineRegistry
class max.pipelines.lib.registry.PipelineRegistry(architectures)
-
Parameters:
-
architectures (list[SupportedArchitecture])
get_active_huggingface_config()
get_active_huggingface_config(huggingface_repo)
Retrieves or creates a cached HuggingFace AutoConfig for the given model configuration.
This method maintains a cache of HuggingFace configurations to avoid reloading them unnecessarily, since each reload incurs a Hugging Face Hub API call. If a config for the given model hasn’t been loaded before, a new one is created using AutoConfig.from_pretrained() with the model’s settings.
Note: The cache key (HuggingFaceRepo) includes trust_remote_code in its hash, so configs with different trust settings are cached separately. For multiprocessing, each worker process has its own registry instance with an empty cache, so configs are loaded fresh in each worker.
-
Parameters:
-
huggingface_repo (HuggingFaceRepo) – The HuggingFaceRepo containing the model.
-
Returns:
-
The HuggingFace configuration object for the model.
-
Return type:
-
AutoConfig
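For example, repeated lookups for the same repository hit the Hub only once. This is a minimal sketch; the HuggingFaceRepo constructor arguments shown are assumptions:
from max.pipelines.lib.registry import PipelineRegistry

registry = PipelineRegistry(architectures=[])
repo = HuggingFaceRepo(repo_id="your-org/your-model-name")  # kwargs assumed; import path omitted

config = registry.get_active_huggingface_config(repo)         # triggers a Hub API call
config_cached = registry.get_active_huggingface_config(repo)  # returned from the cache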
get_active_tokenizer()
get_active_tokenizer(huggingface_repo)
Retrieves or creates a cached HuggingFace AutoTokenizer for the given model configuration.
This method maintains a cache of HuggingFace tokenizers to avoid reloading them unnecessarily, since each reload incurs a Hugging Face Hub API call. If a tokenizer for the given model hasn’t been loaded before, a new one is created using AutoTokenizer.from_pretrained() with the model’s settings.
-
Parameters:
-
huggingface_repo (HuggingFaceRepo) – The HuggingFaceRepo containing the model.
-
Returns:
-
The HuggingFace tokenizer for the model.
-
Return type:
-
PreTrainedTokenizer | PreTrainedTokenizerFast
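The tokenizer cache behaves the same way as the config cache. A short sketch, reusing registry and repo from the example above:
tokenizer = registry.get_active_tokenizer(repo)         # loads via AutoTokenizer.from_pretrained()
tokenizer_cached = registry.get_active_tokenizer(repo)  # no second Hub call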
register()
register(architecture, *, allow_override=False)
Add a new architecture to the registry.
-
Parameters:
-
- architecture (SupportedArchitecture)
- allow_override (bool)
-
Return type:
-
None
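A sketch of registering an architecture, where my_architecture is constructed as in the SupportedArchitecture example below:
registry = PipelineRegistry(architectures=[])
registry.register(my_architecture)

# Replacing an entry that is already registered requires an explicit opt-in.
registry.register(my_architecture, allow_override=True)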
reset()
reset()
-
Return type:
-
None
retrieve()
retrieve(pipeline_config, task=PipelineTask.TEXT_GENERATION, override_architecture=None)
-
Parameters:
-
- pipeline_config (PipelineConfig)
- task (PipelineTask)
- override_architecture (str | None)
-
Return type:
-
tuple[PipelineTokenizer[Any, Any, Any], PipelineTypes]
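For example, a sketch in which the PipelineConfig constructor arguments are assumptions:
config = PipelineConfig(model_path="your-org/your-model-name")  # kwargs assumed
tokenizer, pipeline = registry.retrieve(config, task=PipelineTask.TEXT_GENERATION)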
retrieve_architecture()
retrieve_architecture(huggingface_repo, use_module_v3=False)
-
Parameters:
-
- huggingface_repo (HuggingFaceRepo)
- use_module_v3 (bool)
-
Return type:
-
SupportedArchitecture | None
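Because this returns None rather than raising when nothing matches, callers can probe for support. A sketch reusing repo from the earlier example:
arch = registry.retrieve_architecture(repo)
if arch is None:
    print(f"No registered architecture matches {repo}")
else:
    print(f"Found {arch.name}")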
retrieve_context_type()
retrieve_context_type(pipeline_config)
Retrieve the context class type associated with the architecture for the given pipeline configuration.
The context type defines how the pipeline manages request state and inputs during model execution. Different architectures may use different context implementations that adhere to either the TextGenerationContext or EmbeddingsContext protocol.
-
Parameters:
-
pipeline_config (PipelineConfig) – The configuration for the pipeline.
-
Returns:
-
The context class type associated with the architecture, which implements either the TextGenerationContext or EmbeddingsContext protocol.
-
Raises:
-
ValueError – If no supported architecture is found for the given model repository.
-
Return type:
-
type[TextGenerationContext] | type[EmbeddingsContext]
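The returned value is a class, not an instance, so it can be inspected or instantiated by the caller. A sketch reusing config from the earlier example:
context_cls = registry.retrieve_context_type(config)
print(context_cls.__name__)  # the architecture's context implementation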
retrieve_factory()
retrieve_factory(pipeline_config, task=PipelineTask.TEXT_GENERATION, override_architecture=None)
-
Parameters:
-
- pipeline_config (PipelineConfig)
- task (PipelineTask)
- override_architecture (str | None)
-
Return type:
-
tuple[PipelineTokenizer[Any, Any, Any], Callable[[], PipelineTypes]]
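Unlike retrieve(), this returns a zero-argument factory instead of a constructed pipeline, so instantiation can be deferred, for example into a worker process. A sketch reusing config from above:
tokenizer, pipeline_factory = registry.retrieve_factory(config)
# ... hand pipeline_factory to a worker ...
pipeline = pipeline_factory()  # the model is only built here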
retrieve_pipeline_task()
retrieve_pipeline_task(pipeline_config)
Retrieve the pipeline task associated with the architecture for the given pipeline configuration.
-
Parameters:
-
pipeline_config (PipelineConfig) – The configuration for the pipeline.
-
Returns:
-
The task associated with the architecture.
-
Return type:
-
PipelineTask
-
Raises:
-
ValueError – If no supported architecture is found for the given model repository.
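For example, reusing config from above:
task = registry.retrieve_pipeline_task(config)
print(task)  # e.g. PipelineTask.TEXT_GENERATION for a text-generation architecture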
retrieve_tokenizer()
retrieve_tokenizer(pipeline_config, override_architecture=None)
Retrieves a tokenizer for the given pipeline configuration.
-
Parameters:
-
- pipeline_config (PipelineConfig) – Configuration for the pipeline
- override_architecture (str | None) – Optional architecture override string
-
Returns:
-
The configured tokenizer
-
Return type:
-
PipelineTokenizer[Any, Any, Any]
-
Raises:
-
ValueError – If no supported architecture is found
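For example, reusing config from above; the override string names a registered architecture:
tokenizer = registry.retrieve_tokenizer(config)

# Or force a specific architecture by name (hypothetical name):
tokenizer = registry.retrieve_tokenizer(config, override_architecture="MyModelForCausalLM")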
SupportedArchitecture
class max.pipelines.lib.registry.SupportedArchitecture(name, example_repo_ids, default_encoding, supported_encodings, pipeline_model, task, tokenizer, default_weights_format, context_type, rope_type=RopeType.none, weight_adapters=<factory>, multi_gpu_supported=False, required_arguments=<factory>, context_validators=<factory>, supports_empty_batches=False, requires_max_batch_context_length=False)
Represents a model architecture configuration for MAX pipelines.
This class defines all the necessary components and settings required to support a specific model architecture within the MAX pipeline system. Each SupportedArchitecture instance encapsulates the model implementation, tokenizer, supported encodings, and other architecture-specific configuration.
New architectures should be registered into the PipelineRegistry using the register() method.
Example:
my_architecture = SupportedArchitecture(
    name="MyModelForCausalLM",  # Must match your Hugging Face model class name
    example_repo_ids=[
        "your-org/your-model-name",  # Add example model repository IDs
    ],
    default_encoding=SupportedEncoding.q4_k,
    supported_encodings={
        SupportedEncoding.q4_k: [KVCacheStrategy.PAGED],
        SupportedEncoding.bfloat16: [KVCacheStrategy.PAGED],
        # Add other encodings your model supports
    },
    pipeline_model=MyModel,
    tokenizer=TextTokenizer,
    default_weights_format=WeightsFormat.safetensors,
    rope_type=RopeType.none,
    weight_adapters={
        WeightsFormat.safetensors: weight_adapters.convert_safetensor_state_dict,
        # Add other weight formats if needed
    },
    multi_gpu_supported=True,  # Set based on your implementation capabilities
    required_arguments={"some_arg": True},
    task=PipelineTask.TEXT_GENERATION,
)
-
Parameters:
-
- name (str)
- example_repo_ids (list[str])
- default_encoding (SupportedEncoding)
- supported_encodings (dict[SupportedEncoding, list[KVCacheStrategy]])
- pipeline_model (type[PipelineModel[Any]])
- task (PipelineTask)
- tokenizer (Callable[[...], PipelineTokenizer[Any, Any, Any]])
- default_weights_format (WeightsFormat)
- context_type (type[TextGenerationContext] | type[EmbeddingsContext])
- rope_type (RopeType)
- weight_adapters (dict[WeightsFormat, Callable[[...], dict[str, WeightData]]])
- multi_gpu_supported (bool)
- required_arguments (dict[str, bool | int | float])
- context_validators (list[Callable[[TextContext | TextAndVisionContext], None]])
- supports_empty_batches (bool)
- requires_max_batch_context_length (bool)
context_type
context_type: type[TextGenerationContext] | type[EmbeddingsContext]
The context class type that this architecture uses for managing request state and inputs.
This should be a class (not an instance) that implements either the TextGenerationContext or EmbeddingsContext protocol, defining how the pipeline processes and tracks requests.
context_validators
context_validators: list[Callable[[TextContext | TextAndVisionContext], None]]
A list of callable validators that verify context inputs before model execution.
These validators are called during context creation to ensure inputs meet model-specific requirements. Validators should raise InputError for invalid inputs, providing early error detection before expensive model operations.
Example:
def validate_single_image(context: TextContext | TextAndVisionContext) -> None:
    if isinstance(context, TextAndVisionContext):
        if context.pixel_values and len(context.pixel_values) > 1:
            raise InputError(
                f"Model supports only 1 image, got {len(context.pixel_values)}"
            )

my_architecture = SupportedArchitecture(
    # ... other fields ...
    context_validators=[validate_single_image],
)
default_encoding
default_encoding: SupportedEncoding
The default quantization encoding to use when no specific encoding is requested.
default_weights_format
default_weights_format: WeightsFormat
The weights format expected by the pipeline_model.
example_repo_ids
example_repo_ids: list[str]
A list of Hugging Face repository IDs that use this architecture; used for testing and validation.
multi_gpu_supported
multi_gpu_supported: bool = False
Whether the architecture supports multi-GPU execution.
name
name: str
The name of the model architecture that must match the Hugging Face model class name.
pipeline_model
pipeline_model: type[PipelineModel[Any]]
The PipelineModel class that defines the model graph structure and execution logic.
required_arguments
required_arguments: dict[str, bool | int | float]
A dictionary specifying required values for PipelineConfig options.
requires_max_batch_context_length
requires_max_batch_context_length: bool = False
Whether the architecture requires a max batch context length to be specified.
If True and max_batch_context_length is not specified, it defaults to the model’s maximum sequence length.
rope_type
rope_type: RopeType = RopeType.none
The type of RoPE (Rotary Position Embedding) used by the model.
supported_encodings
supported_encodings: dict[SupportedEncoding, list[KVCacheStrategy]]
A dictionary mapping supported quantization encodings to their compatible KV cache strategies.
supports_empty_batches
supports_empty_batches: bool = False
Whether the architecture can handle empty batches during inference.
When set to True, the pipeline can process requests with zero-sized batches without errors. This is useful for certain execution modes and expert parallelism. Most architectures do not require empty batch support and should leave this as False.
task
task: PipelineTask
The pipeline task type that this architecture supports.
tokenizer
tokenizer: Callable[[...], PipelineTokenizer[Any, Any, Any]]
A callable that returns a PipelineTokenizer instance for preprocessing model inputs.
tokenizer_cls
property tokenizer_cls: type[PipelineTokenizer[Any, Any, Any]]
weight_adapters
weight_adapters: dict[WeightsFormat, Callable[[...], dict[str, WeightData]]]
A dictionary of weight format adapters for converting checkpoints from different formats to the default format.
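A sketch of what an adapter entry might look like; the adapter’s exact signature is an assumption, since the type above only requires a Callable[..., dict[str, WeightData]]:
def convert_my_format_state_dict(state_dict, **kwargs) -> dict[str, WeightData]:
    # Hypothetical adapter: rename checkpoint tensors to the names the
    # pipeline_model expects in the default weights format.
    return {key.removeprefix("model."): value for key, value in state_dict.items()}

weight_adapters = {WeightsFormat.gguf: convert_my_format_state_dict}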
get_pipeline_for_task()
max.pipelines.lib.registry.get_pipeline_for_task(task, pipeline_config)
-
Parameters:
-
- task (PipelineTask)
- pipeline_config (PipelineConfig)
-
Return type:
-
type[TextGenerationPipeline[TextContext]] | type[EmbeddingsPipeline] | type[AudioGeneratorPipeline] | type[StandaloneSpeculativeDecodingPipeline] | type[SpeechTokenGenerationPipeline] | type[EAGLESpeculativeDecodingPipeline]
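For example, resolving the concrete pipeline class for a task, with config as in the earlier sketches:
pipeline_cls = get_pipeline_for_task(PipelineTask.TEXT_GENERATION, config)
print(pipeline_cls.__name__)  # e.g. TextGenerationPipeline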