
Serve custom model architectures

MAX comes with built-in support for popular model architectures like Gemma3ForCausalLM, Qwen2ForCausalLM, and LlamaForCausalLM, so you can deploy them instantly by passing a Hugging Face model name to the max serve command (explore our model repo). You can also use the same command to serve a custom model architecture through its OpenAI-compatible API.

In this tutorial, you'll build a custom architecture for a model called MyModel using our Python API, implement components for MAX, and serve your model with an OpenAI-compatible endpoint. By the end of this tutorial, you'll understand how to:

  • Set up the required file structure for custom architectures.
  • Register the model for MAX.
  • Configure weight format conversions.
  • Serve your model and make inference requests.

Set up your environment

Create a Python project and install the necessary dependencies:

  1. Create a project folder:
    mkdir my_model && cd my_model
  2. Create and activate a virtual environment:
    python3 -m venv .venv/my_model \
    && source .venv/my_model/bin/activate
  3. Install the modular Python package:
    pip install modular \
    --extra-index-url https://download.pytorch.org/whl/cpu \
    --index-url https://dl.modular.com/public/nightly/python/simple/

Understand the architecture structure

Before creating your custom architecture, let's understand how MAX organizes model implementations. Each architecture follows a consistent structure that separates different concerns:

my_model/
├── __init__.py
├── arch.py
├── model.py
├── model_config.py
└── weight_adapters.py

Here's what each file does:

  • __init__.py: Makes your architecture discoverable.

  • arch.py: Registers your model with MAX, specifying supported encodings and capabilities.

  • model.py: Contains the core model implementation and computation graph logic.

  • model_config.py: Handles configuration parsing.

  • weight_adapters.py: Converts model weights from formats like SafeTensors or GGUF into the layout your implementation expects.

Implement the main model class

Start by creating model.py with your core model implementation. This is where you'll implement the main logic:

model.py
from typing import Dict, List, Optional
from max.pipelines.lib import PipelineModel
from max.graph import Graph, TensorType, DeviceRef, DType
from transformers import AutoConfig, AutoTokenizer

from .model_config import MyModelConfig
from .weight_adapters import convert_safetensor_state_dict


class MyModel(PipelineModel):
    """Main model class that implements your custom architecture."""

    def __init__(self, config: MyModelConfig, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.config = config

    @classmethod
    def from_huggingface(cls, model_path: str, **kwargs) -> "MyModel":
        """Create a MyModel instance from a Hugging Face model."""
        # Load the Hugging Face configuration
        hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

        # Convert to our internal configuration format
        config = MyModelConfig.from_huggingface_config(hf_config)

        return cls(config=config, **kwargs)

    def build_graph(self) -> Graph:
        """Build the computation graph for your model.

        This method defines how your model processes inputs and produces outputs.
        You'll implement the actual neural network logic here.
        """
        # Define input types for your model
        input_types = [
            TensorType(
                DType.int64,  # Token IDs
                shape=["batch_size", "sequence_length"],
                device=DeviceRef.GPU(),
            )
        ]

        # Create the computation graph
        with Graph("my_model", input_types=input_types) as graph:
            # Get graph inputs
            (input_ids,) = graph.inputs

            # TODO: Implement your model's forward pass here.
            # This is where you'd add your custom layers, attention mechanisms, etc.
            # For now, we'll add a placeholder.

            # Example placeholder - replace with your actual model logic
            output = input_ids  # Placeholder

            # Set graph outputs
            graph.output(output)

        return graph

The build_graph() method is where you'll implement your model's actual neural network logic. This example uses a placeholder; in a real implementation, you'll add your specific layers, attention mechanisms, and forward-pass logic based on your architecture. For more information, see the Get started with MAX graphs tutorial.

Define your architecture registration

Create the arch.py file that tells MAX about your model's capabilities using the SupportedArchitecture class.

arch.py
from max.graph.weights import WeightsFormat
from max.nn.kv_cache import KVCacheStrategy
from max.pipelines.core import PipelineTask
from max.pipelines.lib import (
    SupportedArchitecture,
    SupportedEncoding,
    TextTokenizer,
)

from . import weight_adapters
from .model import MyModel

my_model_arch = SupportedArchitecture(
    name="MyModelForCausalLM",
    example_repo_ids=[
        "your-org/your-model-name",  # Add example model repository IDs
    ],
    default_encoding=SupportedEncoding.q4_k,
    supported_encodings={
        SupportedEncoding.q4_k: [KVCacheStrategy.PAGED],
        SupportedEncoding.bfloat16: [KVCacheStrategy.PAGED],
        # Add other encodings your model supports
    },
    pipeline_model=MyModel,
    tokenizer=TextTokenizer,
    default_weights_format=WeightsFormat.safetensors,
    multi_gpu_supported=True,  # Set based on your implementation capabilities
    weight_adapters={
        WeightsFormat.safetensors: weight_adapters.convert_safetensor_state_dict,
        # Add other weight formats if needed
    },
    task=PipelineTask.TEXT_GENERATION,
)

The SupportedArchitecture configuration defines:

  • name: Matches the model class name in your Hugging Face model's configuration, for example MyModelForCausalLM (you can verify this with the check after this list).

  • example_repo_ids: List of repository IDs that use this architecture. This doesn't need to be an exhaustive list, but it should be a representative sample of the model variants you support.

  • supported_encodings: Which quantization formats and KV cache strategies your model supports.

  • pipeline_model: The main model class, MyModel, that you implemented in model.py.

  • task: Specifies the pipeline task that the model supports.
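
To confirm the exact string to pass as name, you can read the architectures field from the model's config.json using the transformers library. A minimal check; the repository ID is a placeholder:

from transformers import AutoConfig

# "your-org/your-model-name" is a placeholder repository ID.
hf_config = AutoConfig.from_pretrained(
    "your-org/your-model-name", trust_remote_code=True
)

# The name registered in arch.py must match one of these entries,
# for example ["MyModelForCausalLM"].
print(hf_config.architectures)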

Implement model configuration handling

Most models hosted on Hugging Face contain a config.json file that defines the model's architecture parameters, such as layer dimensions, attention heads, and activation functions. However, MAX's internal graph-building system requires these parameters in a specific format optimized for performance and graph construction.

To bridge this gap, you'll create a translation layer that helps you handle situations where parameter names differ between Hugging Face and MAX, set sensible defaults for missing values, and ensure all the configuration data is in the right format for your MAX implementation.

Create model_config.py to handle this configuration:

model_config.py
from dataclasses import dataclass
from typing import Any, Dict

from transformers import AutoConfig


@dataclass
class MyModelConfig:
    """Configuration class for your custom model.

    This handles the translation between Hugging Face's config.json format
    and your model's internal parameter requirements for MAX graph building.
    """

    # Core model parameters
    vocab_size: int
    hidden_size: int
    num_attention_heads: int
    num_hidden_layers: int
    max_sequence_length: int

    # Add your model-specific parameters here
    intermediate_size: int = 11008
    rms_norm_eps: float = 1e-6

    @classmethod
    def from_huggingface_config(cls, hf_config: AutoConfig) -> "MyModelConfig":
        """Create MyModelConfig from a Hugging Face AutoConfig."""
        return cls(
            vocab_size=hf_config.vocab_size,
            hidden_size=hf_config.hidden_size,
            num_attention_heads=hf_config.num_attention_heads,
            num_hidden_layers=hf_config.num_hidden_layers,
            max_sequence_length=getattr(hf_config, "max_position_embeddings", 2048),
            # Map other parameters from your Hugging Face config
            intermediate_size=getattr(hf_config, "intermediate_size", 11008),
            rms_norm_eps=getattr(hf_config, "rms_norm_eps", 1e-6),
        )

This configuration class acts as a critical bridge between Hugging Face's standardized config.json format and your model's specific implementation needs within MAX's graph system.
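
As a quick sanity check, you can exercise this translation layer on its own before wiring it into the rest of the pipeline. A minimal sketch, assuming the my_model package is importable from your working directory and using a placeholder repository ID:

from transformers import AutoConfig

from my_model.model_config import MyModelConfig

# Load the upstream config.json ("your-org/your-model-name" is a placeholder).
hf_config = AutoConfig.from_pretrained(
    "your-org/your-model-name", trust_remote_code=True
)

# Translate it into the internal format used for MAX graph building.
config = MyModelConfig.from_huggingface_config(hf_config)
print(config.hidden_size, config.num_hidden_layers, config.max_sequence_length)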

Implement weight format conversion

Different model formats and frameworks store weights in varying layouts and naming conventions that may not match your model's expected format. Create weight_adapters.py to handle weight format conversions:

weight_adapters.py
from typing import Dict, Any

from max.graph.weights import WeightData


def convert_safetensor_state_dict(
    state_dict: Dict[str, WeightData],
) -> Dict[str, WeightData]:
    """Convert SafeTensors weights to the format expected by your model.

    Args:
        state_dict: Raw weights loaded from SafeTensors format.

    Returns:
        Converted weights ready for your model implementation.
    """
    converted_weights = {}

    for key, weight in state_dict.items():
        # Apply any necessary transformations to weight names or values.
        # This is where you handle differences between Hugging Face naming
        # conventions and what your model expects.

        # Example: Remove prefixes that your model doesn't expect
        clean_key = key.replace("model.", "")

        # Example: Transpose weights if needed for your architecture
        if "linear" in clean_key and len(weight.shape) == 2:
            # Your model might expect different weight orientations
            converted_weights[clean_key] = weight  # Apply transpose if needed
        else:
            converted_weights[clean_key] = weight

    return converted_weights

Weight adapters ensure that regardless of how weights are stored (SafeTensors, GGUF, etc.), they get converted to the format your model expects.
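
If your checkpoint's parameter names differ from your implementation's by more than a simple prefix, an explicit rename map is often easier to maintain than ad hoc string replacement. A minimal sketch of that pattern; the key prefixes below are hypothetical examples, not a real checkpoint layout:

from typing import Dict

from max.graph.weights import WeightData

# Hypothetical mapping from checkpoint key prefixes to the names your
# model implementation expects.
_NAME_MAP = {
    "model.embed_tokens.": "embeddings.",
    "model.layers.": "blocks.",
    "lm_head.": "output_head.",
}


def remap_keys(state_dict: Dict[str, WeightData]) -> Dict[str, WeightData]:
    """Rename checkpoint keys using the explicit prefix map above."""
    remapped = {}
    for key, weight in state_dict.items():
        new_key = key
        for old_prefix, new_prefix in _NAME_MAP.items():
            if new_key.startswith(old_prefix):
                new_key = new_prefix + new_key[len(old_prefix):]
                break
        remapped[new_key] = weight
    return remapped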

Load your architecture

Create an __init__.py file to make your architecture discoverable by MAX:

__init__.py
from .arch import my_model_arch

# MAX looks for this variable when loading custom architectures
ARCHITECTURES = [my_model_arch]

__all__ = ["my_model_arch", "ARCHITECTURES"]

MAX automatically loads any architectures listed in the ARCHITECTURES variable when you specify your module with the --custom-architectures flag.
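
Before starting the server, you can confirm that the module imports cleanly and exposes the ARCHITECTURES list; a quick check run from the parent directory of my_model/:

import my_model

# MAX discovers custom architectures through this module-level list.
print(my_model.ARCHITECTURES)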

Test your custom architecture

You can now test your custom architecture using the --custom-architectures flag:

max serve \
--model-path your-org/your-model-name \
--custom-architectures my_model

The --model-path flag tells MAX which model to use; it accepts either a Hugging Face model ID or a local directory containing a model. The --custom-architectures flag tells MAX to load custom architectures from the Python module we just built.

The server is ready when you see this message:

Server ready on http://0.0.0.0:8000 (Press CTRL+C to quit)

Now you can test your custom architecture. If your architecture performs text generation, send a request to the chat completions endpoint. For example:

curl -X POST http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "your-org/your-model-name",
    "messages": [
      {"role": "user", "content": "Hello! Can you help me with a simple task?"}
    ],
    "max_tokens": 100
  }'
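
Because the server exposes an OpenAI-compatible API, you can also call it with the openai Python client (pip install openai). A minimal sketch, where the model name mirrors the --model-path value and the API key is a placeholder for the local server:

from openai import OpenAI

# Point the client at the local MAX server; the API key is a placeholder.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="your-org/your-model-name",
    messages=[
        {"role": "user", "content": "Hello! Can you help me with a simple task?"}
    ],
    max_tokens=100,
)
print(response.choices[0].message.content)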

Next steps

Congratulations! You've successfully created a custom architecture for MAX pipelines and served it with the max serve command.

For implementation details, see our supported model architectures. Each subdirectory represents a different model family with its own implementation, and the repository contains a variety of architectures you can use as a starting point for your own custom architecture.
