Model bring-up workflow
Model bring-up is the process of adapting a pretrained model from Hugging Face or PyTorch to run in MAX.
To bring a model to MAX, you must implement a model architecture. A model architecture is a Python package that defines the model's compute graph, maps its configuration, and loads its weights. Once implemented, you can serve the model with hardware-optimized kernels, production batching, KV cache management, and multi-GPU distribution.
This guide shows how to write a custom architecture by choosing a reference implementation, adapting its configuration and weights, and extending the compute graph where your model differs.
Architecture directory structure
The majority of text-generation model architectures follow this standard directory layout:
    {your_model}/
    ├── layers/
    ├── __init__.py
    ├── arch.py
    ├── model_config.py
    ├── model.py
    ├── {model_name}.py
    └── weight_adapters.py
Each file serves a distinct role:
- layers/: holds custom layer implementations for components that differ fundamentally from anything in max.nn.
- __init__.py: exposes the ARCHITECTURES list that MAX reads to register each entry.
- arch.py: registers the architecture with MAX, declaring its name (which must match the architectures[0] value in the model's Hugging Face config.json), supported encodings, and the classes that implement the config, model, and weight adapter.
- model_config.py: translates Hugging Face config.json fields into a typed configuration object MAX uses to build the compute graph.
- model.py: the pipeline model class registered in arch.py. When building on a reference architecture, this file subclasses the reference's pipeline model and overrides the attributes that differ.
- {model_name}.py: defines the transformer compute graph. Most models with unique sub-layers, an encoder, or a structure that differs from the generic transformer require this file.
- weight_adapters.py: translates Hugging Face checkpoint weight names to MAX weight slot names.
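As a rough aid while scaffolding, the layout above can be checked with a few lines of standard-library Python. This is a minimal sketch, not part of MAX: the helper name `missing_entries` is hypothetical, and it treats every entry in the listing as expected even though `layers/` and `{model_name}.py` are only needed by some models.

```python
from pathlib import Path

# File names taken from the layout above. "{model_name}.py" is a placeholder
# in the guide, so the caller supplies the real module name.
REQUIRED_FILES = (
    "__init__.py",
    "arch.py",
    "model_config.py",
    "model.py",
    "weight_adapters.py",
)


def missing_entries(package_dir: Path, model_name: str) -> list[str]:
    """Return the expected entries that are absent from package_dir.

    Hypothetical helper for illustration only; MAX does not provide it.
    """
    expected = [*REQUIRED_FILES, f"{model_name}.py"]
    missing = [name for name in expected if not (package_dir / name).is_file()]
    if not (package_dir / "layers").is_dir():
        missing.append("layers/")
    return missing
```

An empty result means every entry from the listing is present. Drop `layers/` or `{model_name}.py` from the check if your model does not need them.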
The following sections go into detail about how each component works and whether they require customization.
Choose a reference architecture
The reference architectures in MAX's codebase are complete implementations of
this package structure. Each one is a working example that serves production
traffic with max serve today. Browse the
architectures/ directory
alongside this guide to see how the files fit together in a finished
implementation.
Most large language models are variations on a small set of transformer blueprints. The attention mechanism, normalization placement, and residual structure in your model are almost certainly identical to an existing reference. Starting from the closest one means you inherit its compute graph, KV cache logic, and hardware optimization for everything you share, rather than reimplementing it.
Your implementation files define only what differs. model_config.py bridges
the Hugging Face config fields to MAX's representation. weight_adapters.py
remaps checkpoint keys to match MAX's naming. model.py overrides the compute
graph where it diverges, and arch.py registers the quantization encodings
your checkpoint supports. If your model's structure is identical to a reference
and only those surface details differ, those files are all you need.
When the differences go deeper than naming and configuration, the next section guides you to the right implementation depth. Choose a reference based on structural features, not model name or family. For example:
- llama3/: Standard dense transformer with GQA, RoPE, pre-RMSNorm, and SwiGLU. The right starting point for most models.
- mistral/: Use when your model uses sliding window (local) attention alternating with global attention layers.
- gemma3/: Use when your model uses a 5:1 local-to-global attention ratio with RMSNorm placed both before and after the attention module.
- deepseekV2/: Use when your model combines Mixture-of-Experts (MoE) routing with Multi-Head Latent Attention (MLA), which compresses the KV cache using low-rank projections.
- deepseekV3/: Use when your model uses MoE routing and MLA with auxiliary-loss-free load balancing.
The sections below build on each other. Config field mapping, weight naming, and KV cache integration are required for every LLM implementation. If your model's structure departs from the reference in bounded, enumerable ways, override the model components that differ. If the differences are structural, extend the compute graph to implement those components from scratch: a different attention mechanism, custom routing, or a positional encoding variant that can't be expressed as an attribute override.
Map config fields
MAX can't use a model's config.json from Hugging Face directly; it needs a
typed configuration object to build the compute graph. The general pattern is to
subclass the reference config class and override the fields where the Hugging
Face convention differs from MAX's representation. The Files and versions
tab on a Hugging Face model page includes the config.json your config class
needs to map. For a field-by-field guide to reading a config and sorting its
fields into the work each implies, see
Read a model config. Use the config to
identify which fields your subclass must override and which MAX field names they
correspond to:
    {
      "architectures": ["YourModelForCausalLM"],
      "model_type": "your_model",
      "hidden_size": 4096,
      "num_attention_heads": 32,
      "num_key_value_heads": 8,
      "num_hidden_layers": 32,
      "vocab_size": 256000,
      "logit_scale": 0.0625
    }
initialize_from_config() is where you read a raw Hugging Face field and
store a derived value. Override this method rather than initialize() when
extending a
reference config: the super() call runs the parent's full field mapping first,
then you add the fields that differ:
    class YourConfig(Llama3Config):
        logits_scaling: float = 16.0

        @classmethod
        def initialize_from_config(cls, pipeline_config, huggingface_config):
            # Run the parent's full field mapping first.
            config = super().initialize_from_config(pipeline_config, huggingface_config)
            # Derive the MAX value from the raw Hugging Face field.
            logit_scale = getattr(huggingface_config, "logit_scale", 0.0625)
            config.logits_scaling = 1.0 / logit_scale  # 0.0625 → 16.0
            return config
Map checkpoint weights
When you construct your MAX Module, you declare each weight with a name,
shape, dtype, and quantization encoding. Checkpoint weights rarely match your
declared weights, so you write a weight adapter function. The function
converts the checkpoint into the weights your Module expects.
Renaming is the most common transformation needed in a weight adapter. However,
you might also need to cast dtypes, drop weights the Module doesn't use, and
transform tensor layouts. To learn how to implement these transformations, see
Write a weight adapter.
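The renaming and dropping steps described above can be sketched with plain dictionaries. This is an illustration under stated assumptions rather than MAX's adapter interface: the function name `adapt_state_dict`, the prefix map, and the dropped suffix are all hypothetical, and ordinary Python values stand in for tensors.

```python
# Hypothetical prefix map: checkpoint prefix -> prefix the Module declares.
RENAMES = {
    "model.embed_tokens.": "embed_tokens.",
    "model.layers.": "layers.",
    "model.norm.": "norm.",
}
# Hypothetical suffixes for checkpoint entries the Module never declares.
DROP_SUFFIXES = ("rotary_emb.inv_freq",)


def adapt_state_dict(checkpoint: dict[str, object]) -> dict[str, object]:
    """Rename checkpoint keys and drop entries the Module does not use."""
    adapted = {}
    for name, tensor in checkpoint.items():
        if name.endswith(DROP_SUFFIXES):
            continue  # Drop weights with no matching slot.
        for old, new in RENAMES.items():
            if name.startswith(old):
                name = new + name[len(old):]
                break  # Apply at most one prefix rename per key.
        adapted[name] = tensor
    return adapted
```

Keys that match no prefix pass through unchanged, so anything the map misses stays visible when you compare the result against the Module's declared weights.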
Connect to the KV cache
MAX pre-allocates the KV cache at startup from values your config provides,
before the first request arrives. Four values drive the allocation:
num_key_value_heads, head_dim, num_layers, and model_max_seq_len. Wrong
values produce a cache the wrong size: the model runs but generates incoherent
output rather than raising an error. The interface that carries these values is
ArchConfigWithAttentionKVCache,
which your config class must subclass and implement.
The four abstract properties your subclass must define are:
- num_key_value_heads: int: total KV heads across all devices. Use num_key_value_heads from the Hugging Face config, not num_attention_heads. For grouped-query attention models the two differ; using the wrong field silently produces a cache sized for the wrong number of heads.
- head_dim: int: dimensionality of each attention head.
- num_layers: int: number of hidden layers.
- model_max_seq_len: int: maximum sequence length the model supports.
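To see why the choice of head count matters, the generic per-sequence KV cache arithmetic can be worked through by hand. This sketch is not MAX's allocator (which manages paged memory). It assumes a 2-byte cache dtype, a sequence length of 8192, and head_dim = hidden_size / num_attention_heads = 128, using the example config values shown earlier.

```python
def kv_cache_bytes(
    num_kv_heads: int,
    head_dim: int,
    num_layers: int,
    seq_len: int,
    bytes_per_element: int = 2,  # Assumption: a 2-byte dtype such as bfloat16.
) -> int:
    """Bytes needed to cache keys and values for one sequence."""
    # Factor of 2: one tensor for keys and one for values, per layer.
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_element


# num_key_value_heads = 8 from the example config.
correct = kv_cache_bytes(num_kv_heads=8, head_dim=128, num_layers=32, seq_len=8192)
# Mistakenly using num_attention_heads = 32 instead.
wrong = kv_cache_bytes(num_kv_heads=32, head_dim=128, num_layers=32, seq_len=8192)

print(correct // 2**20)  # 1024 (MiB)
print(wrong // correct)  # 4
```

With grouped-query attention, reading the wrong field changes the computed size by the ratio of attention heads to KV heads, which is a factor of four in this example.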
MAX calls get_kv_params() on your config to size the cache. The standard
implementation uses a static construct_kv_params() helper that calls
kv_cache_config.to_params() with those values extracted from the Hugging Face
config. The following example shows the pattern all reference architectures
follow:
    @staticmethod
    def construct_kv_params(
        huggingface_config: AutoConfig,
        pipeline_config: PipelineConfig,
        devices: list[DeviceRef],
        kv_cache_config: KVCacheConfig,
        cache_dtype: DType,
    ) -> KVCacheParams:
        return kv_cache_config.to_params(
            dtype=cache_dtype,
            n_kv_heads=huggingface_config.num_key_value_heads,
            head_dim=YourConfig.get_head_dim(huggingface_config),
            num_layers=YourConfig.get_num_layers(huggingface_config),
            devices=devices,
            data_parallel_degree=pipeline_config.model.data_parallel_degree,
        )
Store the result as a
KVCacheParams in a
kv_params field on your config class and return it from get_kv_params().
Then pass it to each
AttentionWithRope layer
at construction time:
    self.attn = AttentionWithRope(
        kv_params=config.kv_params,
        # ... remaining arguments
    )
Most architectures return a flat KVCacheParams. Architectures with multiple or
hybrid KV caches (for example, separate sliding-window and full-attention
caches) return a
MultiKVCacheParams
tree instead, which the cache manager consumes through the shared
KVCacheParamInterface.
With kv_params correctly configured, MAX dispatches fused attention kernels
(RoPE application, KV store, and flash attention) and manages paged memory
allocation. See
Serve custom model architectures
for a complete implementation with
KVCacheConfig
wired end-to-end.
Override model components
Overriding model components builds on config mapping and weight translation from
the previous sections. The addition is a model.py that subclasses the
reference model class and overrides only the components that differ. Hugging
Face model implementations already follow this pattern: in the transformers
library, many model classes subclass a family base and override only what
changed. The same approach applies to MAX's nn.Module hierarchy.
Subclass the reference model
model.py imports and subclasses the reference model class. Override only the
class attributes or methods that differ from the reference. For example, Qwen2
is structurally identical to Llama 3 except that it uses attention bias on the
projection layers. The entire model override is a single class attribute:
    from max.pipelines.architectures.llama3.model import Llama3Model

    class Qwen2Model(Llama3Model):
        attention_bias: bool = True
The subclass inherits the full graph-building and compilation logic from the reference architecture. Verify that no weight slot your subclass introduces is absent from the checkpoint, and that no checkpoint key is left without a corresponding slot.
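The reason a one-line override needs that verification can be shown with a toy model in plain Python. The classes below are hypothetical stand-ins, not MAX classes: they only illustrate how a single class attribute can change which weight slots the inherited build logic creates.

```python
class ReferenceModel:
    """Toy stand-in for a reference pipeline model (not a MAX class)."""

    attention_bias: bool = False

    def projection_weight_names(self, layer: int) -> list[str]:
        """List the weight slots the attention projections would declare."""
        names = []
        for proj in ("q_proj", "k_proj", "v_proj"):
            names.append(f"layers.{layer}.self_attn.{proj}.weight")
            if self.attention_bias:
                # The override adds one bias slot per projection.
                names.append(f"layers.{layer}.self_attn.{proj}.bias")
        return names


class BiasedModel(ReferenceModel):
    # The only difference from the reference: one class attribute.
    attention_bias: bool = True


print(len(ReferenceModel().projection_weight_names(0)))  # 3
print(len(BiasedModel().projection_weight_names(0)))     # 6
```

The inherited method is unchanged, yet the subclass declares twice as many slots per layer, and each new slot needs a matching checkpoint key.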
Extend the compute graph
Subclassing overrides attributes or methods within an existing graph structure.
Graph extension changes the structure itself: how many normalization layers a
block applies, whether attention and MLP share their input, which parameters are
learnable. What those changes look like depends entirely on your model's
architecture. max.nn provides the building blocks
(attention layers, normalization, MLP, embeddings), and most structural
differences can be expressed by composing from what's already there. The
examples below show two cases where a model's structure differed from the
reference and required graph-level changes.
Parallel decoder block: Llama 3 applies attention and MLP sequentially, each with its own normalization and residual add. Some architectures instead apply both to the same normed input and add them together in one step:
    # Sequential
    x = x + self.self_attn(self.input_layernorm(x), ...)
    x = x + self.mlp(self.post_attention_layernorm(x))
    # Parallel decoder
    normed = self.input_layernorm(x)
    x = x + self.self_attn(normed, ...) + self.mlp(normed)
The parallel form requires a single norm per block (input_layernorm, with no
post_attention_layernorm) and a single residual add.
When building any nn.Module that holds a list of layers, store them in
LayerList rather than a plain
Python list. LayerList registers each layer in the nn.Module hierarchy
under a numeric index (0, 1, …), so weight names resolve to
layers.0.*, layers.1.*, and so on. A plain Python list doesn't register
its items in the hierarchy: MAX won't find those weights during loading, and
the layers run with uninitialized values.
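The registration behavior described above can be mimicked with a toy module system. This is a simplified sketch and not MAX's nn.Module or LayerList implementation: the `Toy*` classes are hypothetical, and floats stand in for weights.

```python
class ToyModule:
    """Toy base class that discovers weights by walking its attributes."""

    def named_weights(self, prefix: str = "") -> list[str]:
        names = []
        for attr, value in vars(self).items():
            if isinstance(value, ToyModule):
                # Registered child: recurse with its attribute name as prefix.
                names.extend(value.named_weights(f"{prefix}{attr}."))
            elif isinstance(value, float):
                # Floats stand in for weight tensors in this sketch.
                names.append(f"{prefix}{attr}")
            # Anything else, including a plain list, is not traversed.
        return names


class ToyLayerList(ToyModule):
    """Registers each layer as a child under its numeric index."""

    def __init__(self, layers):
        for index, layer in enumerate(layers):
            setattr(self, str(index), layer)


class ToyLinear(ToyModule):
    def __init__(self):
        self.weight = 0.0


class Registered(ToyModule):
    def __init__(self):
        self.layers = ToyLayerList([ToyLinear(), ToyLinear()])


class Unregistered(ToyModule):
    def __init__(self):
        self.layers = [ToyLinear(), ToyLinear()]  # Plain list: invisible.


print(Registered().named_weights())    # ['layers.0.weight', 'layers.1.weight']
print(Unregistered().named_weights())  # []
```

The second model still holds its layers and could run them, but the weight walk never sees them, which is the failure mode the paragraph above warns about.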
Learnable LayerNorm: Some reference architectures expose norm type through
a norm_method parameter, but not all norm types are equivalent. Llama 3's
norm_method="layer_norm" creates a ConstantLayerNorm
with fixed gamma=1 and beta=0. If your model has learnable scale weights in
the checkpoint, passing this parameter leaves them with no slot to load into:
they are silently dropped
and the model generates incoherent text. Instantiate the norm directly when
you need learnable parameters:
    from max.nn import LayerNorm

    self.input_layernorm = LayerNorm(
        dims=config.hidden_size,
        eps=config.rms_norm_eps,
        use_bias=False,
    )
LayerNorm(use_bias=False) has a learnable weight
attribute that the weight adapter maps from the checkpoint.
ConstantLayerNorm has none.
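The numerical effect of the missing scale can be seen with the standard layer normalization formula in plain Python. This is a generic sketch of the math, not MAX's kernel, and the learned weight values are made up for illustration.

```python
import math


def layer_norm(x: list[float], weight: list[float], eps: float = 1e-5) -> list[float]:
    """Layer normalization without bias: normalize, then scale per element."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)  # Population variance.
    inv_std = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * inv_std * w for v, w in zip(x, weight)]


x = [1.0, 2.0, 3.0]
# A constant norm behaves as if every scale were 1.
constant = layer_norm(x, weight=[1.0, 1.0, 1.0])
# Hypothetical learned scales as they might appear in a checkpoint.
learned = layer_norm(x, weight=[0.5, 2.0, 4.0])

print(constant != learned)  # True
```

Every element with a scale other than 1 comes out different, and that error feeds into each subsequent layer.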
When you reuse subcomponents from a reference architecture (attention layers,
normalization, MLP), verify that those subcomponents don't define parameters
your model's checkpoint doesn't have. A subcomponent that creates weight slots
with no matching checkpoint keys means those parameters stay at their
initialized values. Call load_state_dict with strict=True to surface any
such mismatch as an error. Check the subcomponent's state_dict() against
your checkpoint's keys before reusing it.
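The key comparison suggested above reduces to two set differences. A minimal sketch, assuming you already have both name lists as strings; the helper name `diff_weight_names` and the example keys are hypothetical.

```python
def diff_weight_names(module_keys, checkpoint_keys) -> dict[str, list[str]]:
    """Report names that appear on only one side of the comparison."""
    module_keys, checkpoint_keys = set(module_keys), set(checkpoint_keys)
    return {
        # Slots the module declares that the checkpoint cannot fill.
        "missing_from_checkpoint": sorted(module_keys - checkpoint_keys),
        # Checkpoint entries that no slot will consume.
        "unused_in_checkpoint": sorted(checkpoint_keys - module_keys),
    }


report = diff_weight_names(
    module_keys=["attn.q_proj.weight", "attn.q_proj.bias"],
    checkpoint_keys=["attn.q_proj.weight", "attn.rotary_emb.inv_freq"],
)
print(report["missing_from_checkpoint"])  # ['attn.q_proj.bias']
print(report["unused_in_checkpoint"])     # ['attn.rotary_emb.inv_freq']
```

Both lists should be empty, or contain only entries you deliberately drop in the weight adapter, before you reuse the subcomponent.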
Implement custom layers
layers/ holds the individual building blocks that model.py composes into
the full architecture. Each file in layers/ implements one component;
model.py describes how they connect, what the residual stream looks like, and
how they loop over layers.
When your architecture requires multiple custom components (custom attention,
MoE routing, a non-standard transformer block), implement each in a separate
file in layers/ rather than putting everything in model.py.
Expose the architecture
The arch.py file ties the rest of the package together. It should include one
or more
SupportedArchitecture
instances that point MAX at your config class, model class, weight adapters, and
more:
    qwen2_arch = SupportedArchitecture(
        name="Qwen2ForCausalLM",
        task=PipelineTask.TEXT_GENERATION,
        pipeline_model=Qwen2Model,
        config=Qwen2Config,
        supported_encodings={"float32", "bfloat16"},
        weight_adapters={
            WeightsFormat.safetensors: weight_adapters.convert_safetensor_state_dict,
        },
        # ... tokenizer, context_type, default_weights_format, ...
    )
After you create a SupportedArchitecture instance, expose it through your
package's __init__.py so MAX can register it when serving:
    from .arch import qwen2_arch

    ARCHITECTURES = [qwen2_arch]
Most packages expose a single architecture, but you can list more than one when:
- There are multiple tasks for the same model. The registry allows entries that share a name as long as their task differs (for example, a text generation architecture and an embedding architecture for the same model).
- There are multiple Hugging Face class names. A single package that covers a model family might register entries for the several architectures[0] strings that its related variants declare in config.json.
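The rule that entries may share a name when their task differs is consistent with a registry keyed on the (name, task) pair. The toy class below illustrates that idea only. It is not MAX's registry, and the class name, method names, and task strings are hypothetical.

```python
class ToyRegistry:
    """Toy registry keyed on (name, task); not MAX's implementation."""

    def __init__(self):
        self._entries = {}

    def register(self, name: str, task: str, arch: object) -> None:
        key = (name, task)
        if key in self._entries:
            raise ValueError(f"{name!r} is already registered for task {task!r}")
        self._entries[key] = arch

    def lookup(self, name: str, task: str) -> object:
        return self._entries[(name, task)]


registry = ToyRegistry()
# Same name, different tasks: both registrations succeed.
registry.register("YourModelForCausalLM", "text_generation", "generation_arch")
registry.register("YourModelForCausalLM", "embeddings", "embedding_arch")
print(registry.lookup("YourModelForCausalLM", "embeddings"))  # embedding_arch
```

Registering the same name and task a second time raises an error in this sketch, which mirrors the constraint that only the task can distinguish entries sharing a name.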
Next steps
With your architecture implemented, see Serve custom model architectures to run a generation, package your implementation, and load it at serve time.
If your model requires an operation that max.nn
doesn't provide and can't be composed from what's already there, see
Custom ops for the interface, registration, and
worked examples.
Once your architecture loads and generates text, verify its forward pass
numerically against a PyTorch reference. See
Logit comparison for the full workflow, or use
the debug-model
skill to run that per-layer comparison with an AI coding agent.