Model bring-up workflow
MAX organizes model support through architectures: Python module directories that define the compute graph, map configuration, and load weights. An architecture connects your Hugging Face checkpoint to the MAX pipeline and gives you hardware-optimized kernels, production batching, KV cache management, and multi-GPU distribution, with no serving layer to write yourself.
MAX ships with implementations for the most common transformer families: dense
transformers, mixture-of-experts variants, and sliding-window attention models.
Most publicly released checkpoints are structurally close enough to one of
those that adopting an existing architecture requires only weight renaming or a
few config field overrides. Some models require new layers that differ
minimally from an existing one, such as a different activation function or an
added normalization; start those from a copy of the closest layer in max.nn
rather than writing from scratch. When your model's compute graph differs
structurally from any of the available implementations, you write a new one.
This guide starts with the directory structure every custom architecture shares, then covers the two decisions that shape your implementation: which reference architecture to start from, and how far your model's compute graph departs from it.
Architecture directory structure
A custom architecture lives in a directory:
{your_model}/
├── layers/
├── __init__.py
├── arch.py
├── model_config.py
├── model.py
├── {model_name}.py
└── weight_adapters.py

- layers/: holds custom layer implementations for components that differ fundamentally from anything in max.nn.
- __init__.py: exposes the ARCHITECTURES list that MAX reads to register each entry.
- arch.py: registers the architecture with MAX, declaring its name (which must match the architectures[0] value in the model's Hugging Face config.json), supported encodings, and the classes that implement the config, model, and weight adapter.
- model_config.py: translates Hugging Face config.json fields into a typed configuration object MAX uses to build the compute graph.
- model.py: the pipeline model class registered in arch.py. When building on a reference architecture, this file subclasses the reference's pipeline model and overrides the attributes that differ.
- {model_name}.py: defines the transformer compute graph. Most models with unique sub-layers, an encoder, or a structure that differs from the generic transformer require this file.
- weight_adapters.py: translates Hugging Face checkpoint weight names to MAX weight slot names.
Pass the parent directory and package name to --custom-architectures when you
run max generate or max serve. MAX imports the package at startup, reads its
ARCHITECTURES list, and registers each entry before the model loads.
Choose a reference architecture
The reference architectures in MAX's codebase are complete implementations of
this package structure. Each one is a working example that serves production
traffic with max serve today. Browse the
architectures/ directory
alongside this guide to see how the files fit together in a finished
implementation.
Most large language models are variations on a small set of transformer blueprints. The attention mechanism, normalization placement, and residual structure in your model are almost certainly identical to an existing reference. Starting from the closest one means you inherit its compute graph, KV cache logic, and hardware optimization for everything you share, rather than reimplementing it.
Your implementation files define only what differs. model_config.py bridges
the Hugging Face config fields to MAX's representation. weight_adapters.py
remaps checkpoint keys to match MAX's naming. model.py overrides the compute
graph where it diverges, and arch.py registers the quantization encodings
your checkpoint supports. If your model's structure is identical to a reference
and only those surface details differ, those files are all you need.
When the differences go deeper than naming and configuration, the next section guides you to the right implementation depth. Choose based on structural features, not model name or family, for example:
- llama3/: Standard dense transformer with GQA, RoPE, pre-RMSNorm, and SwiGLU. The right starting point for most models.
- mistral/: Use when your model uses sliding-window (local) attention alternating with global attention layers.
- gemma3/: Use when your model uses a 5:1 local-to-global attention ratio with RMSNorm placed both before and after the attention module.
- deepseekV2/: Use when your model combines Mixture-of-Experts (MoE) routing with Multi-Head Latent Attention (MLA), which compresses the KV cache using low-rank projections.
- deepseekV3/: Use when your model uses MoE routing and MLA with auxiliary-loss-free load balancing.
The sections below build on each other. Config field mapping, weight naming, and KV cache integration are required for every LLM implementation. If your model's structure departs from the reference in bounded, enumerable ways, override the model components that differ. If the differences are structural, extend the compute graph to implement those components from scratch: a different attention mechanism, custom routing, or a positional encoding variant that can't be expressed as an attribute override.
Map config fields
MAX can't use a model's config.json from Hugging Face directly; it needs a
typed configuration object to build the compute graph. The general pattern is to
subclass the reference config class and override the fields where the Hugging
Face convention differs from MAX's representation. The Files and versions
tab on a Hugging Face model page includes the config.json your config class
needs to map. Use the config to identify which fields your subclass must
override and which MAX field names they correspond to:
{
"architectures": ["YourModelForCausalLM"],
"model_type": "your_model",
"hidden_size": 4096,
"num_attention_heads": 32,
"num_key_value_heads": 8,
"num_hidden_layers": 32,
"vocab_size": 256000,
"logit_scale": 0.0625
}

initialize_from_config() is where you read a raw Hugging Face field and
store a derived value. Override this method rather than initialize() when
extending a
reference config: the super() call runs the parent's full field mapping first,
then you add the fields that differ:
class YourConfig(Llama3Config):
    logits_scaling: float = 16.0

    @classmethod
    def initialize_from_config(cls, pipeline_config, huggingface_config):
        config = super().initialize_from_config(pipeline_config, huggingface_config)
        logit_scale = getattr(huggingface_config, "logit_scale", 0.0625)
        config.logits_scaling = 1.0 / logit_scale  # 0.0625 → 16.0
        return config

Map weight names
You can't load Hugging Face weights directly into a MAX model because both
frameworks derive weight names as attribute paths through an nn.Module
hierarchy, and those hierarchies don't match. To load weights, you need a
function that translates Hugging Face checkpoint keys to MAX weight slot names.
The model's .safetensors.index.json file (in the Files and versions tab
on Hugging Face) shows the checkpoint's key names and shapes. Use it to
identify the prefix and naming patterns your weight adapter must translate:
model.embed_tokens.weight [vocab_size, hidden_size] BF16
model.layers.0.input_layernorm.weight [hidden_size] BF16
model.layers.0.self_attn.q_proj.weight [hidden_size, hidden_size] BF16
model.layers.0.mlp.gate_proj.weight [ffn_size, hidden_size] BF16
model.norm.weight [hidden_size] BF16

The most common transformation is stripping a top-level prefix. A rename map
and a conversion function registered in arch.py are all you need to handle
it:
YOUR_MODEL_SAFETENSOR_MAPPING = {
    "model.": "",  # removes the top-level "model." prefix
}

# pass this function to SupportedArchitecture(weight_adapters=...) in arch.py
def convert_safetensor_state_dict(state_dict, huggingface_config, model_config):
    ...

After MAX compiles the compute graph, it walks the model's state_dict() to
get a list of named weight slots. It calls your adapter to translate Hugging
Face checkpoint keys into slot names, then fills each slot by name. For all
weights to load correctly, your adapter must map every checkpoint key to a slot
name that exists in the model, and the attribute path your nn.Module
hierarchy assigns to each parameter must match what the adapter returns. For
quantized models, check the precision of weights and scales and cast to match
kernel expectations before returning them from the adapter.
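The rename-map pattern can be sketched framework-free. This is an illustrative sketch, not the MAX API: strings stand in for weight tensors, and the function name rename_checkpoint_keys is hypothetical.

```python
# Illustrative rename map, following the "strip the top-level prefix" case
# described above. In a real adapter the values are weight tensors.
YOUR_MODEL_SAFETENSOR_MAPPING = {
    "model.": "",  # strip the top-level "model." prefix
}

def rename_checkpoint_keys(state_dict: dict) -> dict:
    """Apply each prefix rename to every checkpoint key."""
    renamed = {}
    for name, value in state_dict.items():
        for before, after in YOUR_MODEL_SAFETENSOR_MAPPING.items():
            if name.startswith(before):
                name = after + name[len(before):]
        renamed[name] = value
    return renamed

checkpoint = {
    "model.embed_tokens.weight": "tensor0",
    "model.layers.0.self_attn.q_proj.weight": "tensor1",
}
print(sorted(rename_checkpoint_keys(checkpoint)))
# → ['embed_tokens.weight', 'layers.0.self_attn.q_proj.weight']
```

The same loop structure extends to multi-entry maps, applying each prefix rewrite in order.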
MAX supports multiple weight formats through
load_weights():
- .safetensors (Safetensors)
- .gguf (GGUF)
Write a separate adapter function for each format you need to support, following
the same rename-map pattern. For the full API, see
max.graph.weights.
Connect to the KV cache
MAX pre-allocates the KV cache at startup from values your config provides,
before the first request arrives. Four values drive the allocation:
num_key_value_heads, head_dim, num_layers, and model_max_seq_len. Wrong
values produce a cache the wrong size: the model runs but generates incoherent
output rather than raising an error. The interface that carries these values is
ArchConfigWithAttentionKVCache,
which your config class must subclass and implement.
The four abstract properties your subclass must define are:
- num_key_value_heads: int: total KV heads across all devices. Use num_key_value_heads from the Hugging Face config, not num_attention_heads. For grouped-query attention models the two differ; using the wrong field produces a cache sized for the wrong number of heads, silently.
- head_dim: int: dimensionality of each attention head.
- num_layers: int: number of hidden layers.
- model_max_seq_len: int: maximum sequence length the model supports.
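As a framework-free sketch of sourcing these four values, suppose the raw config.json has been parsed into a dict. The head_dim fallback and the use of max_position_embeddings for the sequence length are assumptions about typical Hugging Face configs; check which fields your model actually provides.

```python
from dataclasses import dataclass

@dataclass
class KVCacheFields:
    """Illustrative holder for the four cache-sizing values (not MAX's API)."""
    hf: dict  # parsed config.json

    @property
    def num_key_value_heads(self) -> int:
        # Use num_key_value_heads, not num_attention_heads: for GQA
        # models these differ, and the wrong one mis-sizes the cache.
        return self.hf["num_key_value_heads"]

    @property
    def head_dim(self) -> int:
        # Assumption: derive from hidden_size when the config omits head_dim.
        return self.hf.get(
            "head_dim", self.hf["hidden_size"] // self.hf["num_attention_heads"]
        )

    @property
    def num_layers(self) -> int:
        return self.hf["num_hidden_layers"]

    @property
    def model_max_seq_len(self) -> int:
        # Assumption: many configs expose this as max_position_embeddings.
        return self.hf["max_position_embeddings"]

fields = KVCacheFields({
    "hidden_size": 4096,
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
    "num_hidden_layers": 32,
    "max_position_embeddings": 8192,
})
print(fields.num_key_value_heads, fields.head_dim)  # → 8 128
```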
MAX calls get_kv_params() on your config to size the cache. The standard
implementation uses a static construct_kv_params() helper that calls
kv_cache_config.to_params() with those values extracted from the Hugging Face
config. The following example shows the pattern all reference architectures
follow:
@staticmethod
def construct_kv_params(
    huggingface_config: AutoConfig,
    pipeline_config: PipelineConfig,
    devices: list[DeviceRef],
    kv_cache_config: KVCacheConfig,
    cache_dtype: DType,
) -> KVCacheParams:
    return kv_cache_config.to_params(
        dtype=cache_dtype,
        n_kv_heads=huggingface_config.num_key_value_heads,
        head_dim=YourConfig.get_head_dim(huggingface_config),
        num_layers=YourConfig.get_num_layers(huggingface_config),
        devices=devices,
        data_parallel_degree=pipeline_config.model.data_parallel_degree,
    )

Store the result as a KVCacheParams in a kv_params
field on your config class and return it from get_kv_params(). Then pass it
to each AttentionWithRope layer at construction time:
self.attn = AttentionWithRope(
    kv_params=config.kv_params,
    # ... remaining arguments
)

With kv_params correctly configured, MAX dispatches fused attention kernels
(RoPE application, KV store, and flash attention) and manages paged memory
allocation. See Serve custom model
architectures for a complete
implementation with KVCacheConfig wired end-to-end.
Override model components
Overriding model components builds on config mapping and weight translation from
the previous sections. The addition is a model.py that subclasses the
reference model class and overrides only the components that differ. Hugging
Face model implementations already follow this pattern: in the transformers
library, many model classes subclass a family base and override only what
changed. The same approach applies to MAX's nn.Module hierarchy.
Subclass the reference model
model.py imports and subclasses the reference model class. Override only the
class attributes or methods that differ from the reference. For example, Qwen2
is structurally identical to Llama 3 except that it uses attention bias on the
projection layers. The entire model override is a single class attribute:
from max.pipelines.architectures.llama3.model import Llama3Model
class Qwen2Model(Llama3Model):
    attention_bias: bool = True

The subclass inherits the full graph-building and compilation logic from the reference architecture. Verify that no weight slot your subclass introduces is absent from the checkpoint, and that no checkpoint key is left without a corresponding slot.
Extend the compute graph
Subclassing overrides attributes or methods within an existing graph structure.
Graph extension changes the structure itself: how many normalization layers a
block applies, whether attention and MLP share their input, which parameters are
learnable. What those changes look like depends entirely on your model's
architecture. max.nn provides the building blocks
(attention layers, normalization, MLP, embeddings), and most structural
differences can be expressed by composing from what's already there. The
examples below show two cases where a model's structure differed from the
reference and required graph-level changes.
Parallel decoder block: Llama 3 applies attention and MLP sequentially, each to its own residual stream. Some architectures instead apply both to the same normed input and add them together in one step:
# Sequential
x = x + self.self_attn(self.input_layernorm(x), ...)
x = x + self.mlp(self.post_attention_layernorm(x))

# Parallel decoder
normed = self.input_layernorm(x)
x = x + self.self_attn(normed, ...) + self.mlp(normed)

The parallel form requires one input_layernorm per block (not two) and a
single residual add.
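A toy numeric check makes the distinction concrete. The functions below are arbitrary stand-ins for the norm, attention, and MLP sublayers, not real implementations; the point is only that the two block structures compute different things from the same input.

```python
# Arbitrary stand-in sublayers; illustrative only.
def norm(x):
    return x / 2.0

def attn(x):
    return x + 1.0

def mlp(x):
    return 3.0 * x

x0 = 4.0

# Sequential (Llama 3 style): two norms, two residual adds.
x = x0 + attn(norm(x0))   # the MLP sees the post-attention residual
x_seq = x + mlp(norm(x))

# Parallel: one norm, one residual add; both branches share the input.
normed = norm(x0)
x_par = x0 + attn(normed) + mlp(normed)

print(x_seq, x_par)  # → 17.5 13.0
```

Because the sequential form feeds the MLP a residual that already includes the attention output, the two structures diverge; this is why a parallel block can't be expressed as an attribute override on a sequential one.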
When building any nn.Module that holds a list of layers, store them in
LayerList rather than a plain
Python list. LayerList registers each layer in the nn.Module hierarchy
under a numeric index (0, 1, …), so weight names resolve to
layers.0.*, layers.1.*, and so on. A plain Python list doesn't register
its items in the hierarchy: MAX won't find those weights during loading, and
the layers run with uninitialized values.
Learnable LayerNorm: Some reference architectures expose norm type through
a norm_method parameter, but not all norm types are equivalent. Llama 3's
norm_method="layer_norm" creates a ConstantLayerNorm
with fixed gamma=1 and beta=0. If your model has learnable scale weights in
the checkpoint, passing this parameter loads them silently into empty slots
and the model generates incoherent text. Instantiate the norm directly when
you need learnable parameters:
from max.nn import LayerNorm
self.input_layernorm = LayerNorm(
    dims=config.hidden_size,
    eps=config.rms_norm_eps,
    use_bias=False,
)

LayerNorm(use_bias=False) has a learnable weight
attribute that the weight adapter maps from the checkpoint.
ConstantLayerNorm has none.
When you reuse subcomponents from a reference architecture (attention layers,
normalization, MLP), verify that those subcomponents don't define parameters
your model's checkpoint doesn't have. A subcomponent that creates weight slots
with no matching checkpoint keys means those parameters stay at their
initialized values. Call load_state_dict with strict=True to surface any
such mismatch as an error. Check the subcomponent's state_dict() against
your checkpoint's keys before reusing it.
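This bidirectional check can be sketched framework-free. The helper name is hypothetical; in practice the slot names would come from the module's state_dict() and the checkpoint keys from the .safetensors.index.json file.

```python
def check_weight_coverage(state_dict_keys, checkpoint_keys):
    """Compare weight slots against checkpoint keys in both directions."""
    slots, ckpt = set(state_dict_keys), set(checkpoint_keys)
    missing_in_checkpoint = sorted(slots - ckpt)   # stay at initialized values
    unused_checkpoint_keys = sorted(ckpt - slots)  # silently dropped on load
    return missing_in_checkpoint, unused_checkpoint_keys

# Illustrative names: the checkpoint carries a norm.bias the module lacks.
missing, unused = check_weight_coverage(
    ["norm.weight", "layers.0.self_attn.q_proj.weight"],
    ["norm.weight", "layers.0.self_attn.q_proj.weight", "norm.bias"],
)
print(missing, unused)  # → [] ['norm.bias']
```

Either non-empty list signals the silent mismatch described above before it reaches generation.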
Implement custom layers
layers/ holds the individual building blocks that model.py composes into
the full architecture. Each file in layers/ implements one component;
model.py describes how they connect, what the residual stream looks like, and
how they loop over layers.
When your architecture requires multiple custom components (custom attention,
MoE routing, a non-standard transformer block), implement each in a separate
file in layers/ rather than putting everything in model.py.
Next steps
With your architecture implemented, see Serve custom model architectures to run a generation, package your implementation, and load it at serve time.
If your model requires an operation that max.nn
doesn't provide and can't be composed from what's already there, see
Custom ops for the interface, registration, and
worked examples.