
Serve custom model architectures
MAX comes with built-in support for popular model architectures like Gemma3ForCausalLM, Qwen2ForCausalLM, and LlamaForCausalLM, so you can instantly deploy them by passing a specific Hugging Face model name to the max serve command (explore our model repo). You can also use MAX to serve a custom model architecture with the max serve command, which provides an OpenAI-compatible API.
In this tutorial, you'll build a custom architecture for a model called MyModel
using our Python API, implement components for MAX, and serve your model with an
OpenAI-compatible endpoint. By the end of this tutorial, you'll understand how
to:
- Set up the required file structure for custom architectures.
- Register the model for MAX.
- Configure weight format conversions.
- Serve your model and make inference requests.
Set up your environment
Create a Python project and install the necessary dependencies:
Choose one of the following package managers and follow its steps.

pip

- Create a project folder:

  mkdir my_model && cd my_model

- Create and activate a virtual environment:

  python3 -m venv .venv/my_model \
    && source .venv/my_model/bin/activate

- Install the modular Python package (nightly or stable):

  # Nightly
  pip install modular \
    --extra-index-url https://download.pytorch.org/whl/cpu \
    --index-url https://dl.modular.com/public/nightly/python/simple/

  # Stable
  pip install modular \
    --extra-index-url https://download.pytorch.org/whl/cpu \
    --extra-index-url https://modular.gateway.scarf.sh/simple/

uv

- If you don't have it, install uv:

  curl -LsSf https://astral.sh/uv/install.sh | sh

  Then restart your terminal to make uv accessible.

- Create a project:

  uv init my_model && cd my_model

- Create and start a virtual environment:

  uv venv && source .venv/bin/activate

- Install the modular Python package (nightly or stable):

  # Nightly
  uv pip install modular \
    --extra-index-url https://download.pytorch.org/whl/cpu \
    --index-url https://dl.modular.com/public/nightly/python/simple/ \
    --index-strategy unsafe-best-match

  # Stable
  uv pip install modular \
    --extra-index-url https://download.pytorch.org/whl/cpu \
    --extra-index-url https://modular.gateway.scarf.sh/simple/ \
    --index-strategy unsafe-best-match

conda

- If you don't have it, install conda. A common choice is with brew:

  brew install miniconda

- Initialize conda for shell interaction:

  conda init

  If you're on a Mac, instead use:

  conda init zsh

  Then restart your terminal for the changes to take effect.

- Create a project:

  conda create -n my_model

- Start the virtual environment:

  conda activate my_model

- Install the modular conda package (nightly or stable):

  # Nightly
  conda install -c conda-forge -c https://conda.modular.com/max-nightly/ modular

  # Stable
  conda install -c conda-forge -c https://conda.modular.com/max/ modular

pixi

- If you don't have it, install pixi:

  curl -fsSL https://pixi.sh/install.sh | sh

  Then restart your terminal for the changes to take effect.

- Create a project:

  pixi init my_model \
    -c https://conda.modular.com/max-nightly/ -c conda-forge \
    && cd my_model

- Install the modular conda package (nightly or stable):

  # Nightly
  pixi add modular

  # Stable
  pixi add "modular=25.4"

- Start the virtual environment:

  pixi shell
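To confirm the installation worked, you can import the MAX Python modules used later in this tutorial. This is only a sanity check that the packages resolve; it doesn't exercise any model code:

# Optional install check: these are the same imports used in the files below.
from max.graph import DeviceRef, DType, Graph, TensorType
from max.pipelines.lib import PipelineModel, SupportedArchitecture

print("MAX Python APIs imported successfully")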
Understand the architecture structure
Before creating your custom architecture, let's understand how MAX organizes model implementations. Each architecture follows a consistent structure that separates different concerns:
my_model/
├── __init__.py
├── arch.py
├── model.py
├── model_config.py
└── weight_adapters.py
Here's what each file does:
- __init__.py: Makes your architecture discoverable.
- arch.py: Registers your model with MAX, specifying supported encodings and capabilities.
- model.py: Contains the core model implementation and computation graph logic.
- model_config.py: Handles configuration parsing.
- weight_adapters.py: Converts model weights from formats like SafeTensors or GGUF.
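If you prefer to scaffold this layout from a script rather than creating each file by hand, here's a minimal sketch. It only creates empty files; adjust the path if your package lives somewhere other than a my_model/ directory under your project root:

# scaffold.py -- create the empty module layout shown above
from pathlib import Path

package = Path("my_model")
package.mkdir(exist_ok=True)
for filename in ("__init__.py", "arch.py", "model.py", "model_config.py", "weight_adapters.py"):
    (package / filename).touch()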
Implement the main model class
Start by creating model.py with your core model implementation. This is where you'll implement the main logic:
from typing import Dict, List, Optional

from max.graph import DeviceRef, DType, Graph, TensorType
from max.pipelines.lib import PipelineModel
from transformers import AutoConfig, AutoTokenizer

from .model_config import MyModelConfig
from .weight_adapters import convert_safetensor_state_dict


class MyModel(PipelineModel):
    """Main model class that implements your custom architecture."""

    def __init__(self, config: MyModelConfig, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.config = config

    @classmethod
    def from_huggingface(cls, model_path: str, **kwargs) -> "MyModel":
        """Create a MyModel instance from a Hugging Face model."""
        # Load the Hugging Face configuration
        hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
        # Convert to our internal configuration format
        config = MyModelConfig.from_huggingface_config(hf_config)
        return cls(config=config, **kwargs)

    def build_graph(self) -> Graph:
        """Build the computation graph for your model.

        This method defines how your model processes inputs and produces outputs.
        You'll implement the actual neural network logic here.
        """
        # Define input types for your model
        input_types = [
            TensorType(
                DType.int64,  # Token IDs
                shape=["batch_size", "sequence_length"],
                device=DeviceRef.GPU(),
            )
        ]

        # Create the computation graph
        with Graph("my_model", input_types=input_types) as graph:
            # Get graph inputs
            (input_ids,) = graph.inputs

            # TODO: Implement your model's forward pass here.
            # This is where you'd add your custom layers, attention mechanisms, etc.
            output = input_ids  # Placeholder - replace with your actual model logic

            # Set graph outputs
            graph.output(output)

        return graph
The build_graph()
method is where you'll implement your model's actual neural
network logic. While this example is a placeholder, you'll need to add your
specific layers, attention mechanisms, and forward pass logic based on your
architecture. For more information, see the Get started with MAX graphs
tutorial.
Define your architecture registration
Create the arch.py
file that tells MAX about your model's capabilities using
the
SupportedArchitecture
class.
from max.graph.weights import WeightsFormat
from max.nn.kv_cache import KVCacheStrategy
from max.pipelines.core import PipelineTask
from max.pipelines.lib import (
    SupportedArchitecture,
    SupportedEncoding,
    TextTokenizer,
)

from . import weight_adapters
from .model import MyModel

my_model_arch = SupportedArchitecture(
    name="MyModelForCausalLM",
    example_repo_ids=[
        "your-org/your-model-name",  # Add example model repository IDs
    ],
    default_encoding=SupportedEncoding.q4_k,
    supported_encodings={
        SupportedEncoding.q4_k: [KVCacheStrategy.PAGED],
        SupportedEncoding.bfloat16: [KVCacheStrategy.PAGED],
        # Add other encodings your model supports
    },
    pipeline_model=MyModel,
    tokenizer=TextTokenizer,
    default_weights_format=WeightsFormat.safetensors,
    multi_gpu_supported=True,  # Set based on your implementation capabilities
    weight_adapters={
        WeightsFormat.safetensors: weight_adapters.convert_safetensor_state_dict,
        # Add other weight formats if needed
    },
    task=PipelineTask.TEXT_GENERATION,
)
The SupportedArchitecture configuration defines:

- name: Matches the model class name in your Hugging Face model's configuration, for example MyModelForCausalLM.
- example_repo_ids: List of repository IDs that use this architecture. This doesn't need to be an exhaustive list, but it should be a representative sample of the model variants you support.
- supported_encodings: Which quantization formats and KV cache strategies your model supports.
- pipeline_model: The main model class you implemented in the previous step.
- task: Specifies the pipeline task that the model supports.
Implement model configuration handling
Most models hosted on Hugging Face contain a config.json
file that defines
the model's architecture parameters, such as layer dimensions, attention heads,
and activation functions. However, MAX's internal graph-building system requires
these parameters in a specific format optimized for performance and graph
construction.
To bridge this gap, you'll create a translation layer that helps you handle situations where parameter names differ between Hugging Face and MAX, set sensible defaults for missing values, and ensure all the configuration data is in the right format for your MAX implementation.
Create model_config.py
to handle this configuration:
from dataclasses import dataclass
from typing import Any, Dict

from transformers import AutoConfig


@dataclass
class MyModelConfig:
    """Configuration class for your custom model.

    This handles the translation between Hugging Face's config.json format
    and your model's internal parameter requirements for MAX graph building.
    """

    # Core model parameters
    vocab_size: int
    hidden_size: int
    num_attention_heads: int
    num_hidden_layers: int
    max_sequence_length: int

    # Add your model-specific parameters here
    intermediate_size: int = 11008
    rms_norm_eps: float = 1e-6

    @classmethod
    def from_huggingface_config(cls, hf_config: AutoConfig) -> "MyModelConfig":
        """Create MyModelConfig from a Hugging Face AutoConfig."""
        return cls(
            vocab_size=hf_config.vocab_size,
            hidden_size=hf_config.hidden_size,
            num_attention_heads=hf_config.num_attention_heads,
            num_hidden_layers=hf_config.num_hidden_layers,
            max_sequence_length=getattr(hf_config, "max_position_embeddings", 2048),
            # Map other parameters from your Hugging Face config
            intermediate_size=getattr(hf_config, "intermediate_size", 11008),
            rms_norm_eps=getattr(hf_config, "rms_norm_eps", 1e-6),
        )
This configuration class acts as a critical bridge between Hugging Face's
standardized config.json
format and your model's specific implementation needs
within MAX's graph system.
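To see the mapping in isolation, you can exercise from_huggingface_config() with a stand-in config object. The values below are hypothetical, and because the method only reads attributes, a SimpleNamespace is enough for a quick check (real code passes an AutoConfig loaded from Hugging Face):

# Illustrative only: exercise the config mapping with made-up values.
from types import SimpleNamespace

from my_model.model_config import MyModelConfig

fake_hf_config = SimpleNamespace(
    vocab_size=32000,
    hidden_size=4096,
    num_attention_heads=32,
    num_hidden_layers=32,
    max_position_embeddings=4096,
    intermediate_size=11008,
    rms_norm_eps=1e-6,
)

config = MyModelConfig.from_huggingface_config(fake_hf_config)
print(config.hidden_size, config.max_sequence_length)  # 4096 4096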
Implement weight format conversion
Different model formats and frameworks store weights in varying layouts and
naming conventions that may not match your model's expected format. Create
weight_adapters.py
to handle weight format conversions:
from typing import Any, Dict

from max.graph.weights import WeightData


def convert_safetensor_state_dict(
    state_dict: Dict[str, WeightData],
) -> Dict[str, WeightData]:
    """Convert SafeTensors weights to the format expected by your model.

    Args:
        state_dict: Raw weights loaded from SafeTensors format.

    Returns:
        Converted weights ready for your model implementation.
    """
    converted_weights = {}
    for key, weight in state_dict.items():
        # Apply any necessary transformations to weight names or values.
        # This is where you handle differences between Hugging Face naming
        # conventions and what your model expects.

        # Example: Remove prefixes that your model doesn't expect
        clean_key = key.replace("model.", "")

        # Example: Transpose weights if needed for your architecture
        if "linear" in clean_key and len(weight.shape) == 2:
            # Your model might expect different weight orientations
            converted_weights[clean_key] = weight  # Apply transpose if needed
        else:
            converted_weights[clean_key] = weight

    return converted_weights
Weight adapters ensure that regardless of how weights are stored (SafeTensors, GGUF, etc.), they get converted to the format your model expects.
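For a quick look at what this adapter does, you can call it with a small stand-in state dict. Real state dicts contain WeightData objects loaded by MAX, but anything with a .shape attribute (such as NumPy arrays) is enough to see the key renaming:

# Illustrative only: demonstrates the "model." prefix stripping above.
import numpy as np

from my_model.weight_adapters import convert_safetensor_state_dict

raw_weights = {
    "model.embed_tokens.weight": np.zeros((32000, 4096)),
    "model.linear_out.weight": np.zeros((4096, 4096)),
}
converted = convert_safetensor_state_dict(raw_weights)
print(sorted(converted))  # ['embed_tokens.weight', 'linear_out.weight']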
Load your architecture
Create an __init__.py
file to make your architecture discoverable by MAX:
from .arch import my_model_arch
# MAX looks for this variable when loading custom architectures
ARCHITECTURES = [my_model_arch]
__all__ = ["my_model_arch", "ARCHITECTURES"]
MAX automatically loads any architectures listed in the ARCHITECTURES
variable
when you specify your module with the
--custom-architectures
flag.
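Before starting the server, you can sanity-check that the package is importable and exposes the ARCHITECTURES list MAX looks for. Run this from the directory that contains the my_model/ package:

# Quick discovery check for the ARCHITECTURES list
import my_model

assert len(my_model.ARCHITECTURES) == 1
print(my_model.ARCHITECTURES)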
Test your custom architecture
You can now test your custom architecture using the --custom-architectures
flag:
max serve \
--model-path your-org/your-model-name \
--custom-architectures my_model
The --model-path flag tells MAX which model to use. You can point it at a Hugging Face repository ID or a local directory containing a model. The --custom-architectures flag tells MAX to load custom architectures from the Python module you just built.
The server is ready when you see this message:
Server ready on http://0.0.0.0:8000 (Press CTRL+C to quit)
Now you can test your custom architecture. If you implemented a text generation architecture, you can send a chat completion request to the server. For example:
Using cURL:

curl -X POST http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "your-org/your-model-name",
    "messages": [
      {"role": "user", "content": "Hello! Can you help me with a simple task?"}
    ],
    "max_tokens": 100
  }'
Using Python with the openai client:

from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="EMPTY",  # Required by the API but not used by MAX
)

response = client.chat.completions.create(
    model="your-org/your-model-name",
    messages=[
        {"role": "user", "content": "Hello! Can you help me with a simple task?"}
    ],
    max_tokens=100,
)

print(response.choices[0].message.content)
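Because the endpoint follows the standard OpenAI chat completions API, you can also request tokens incrementally with the stream=True option, assuming streaming is enabled for your deployment. A minimal variation on the client above:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Stream tokens as they are generated instead of waiting for the full reply.
stream = client.chat.completions.create(
    model="your-org/your-model-name",
    messages=[
        {"role": "user", "content": "Hello! Can you help me with a simple task?"}
    ],
    max_tokens=100,
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:
        print(delta, end="", flush=True)
print()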
Next steps
Congratulations! You've successfully created a custom architecture for MAX
pipelines and served it with the max serve
command.
For implementation details, see our supported model architectures. Each subdirectory represents a different model family with its own implementation, and you can use any of these architectures as the base for your own custom architecture.