Model bring-up workflow

MAX organizes model support through architectures: Python module directories that define the compute graph, map configuration, and load weights. An architecture connects your Hugging Face checkpoint to the MAX pipeline and gives you hardware-optimized kernels, production batching, KV cache management, and multi-GPU distribution, with no serving layer to write yourself.

MAX ships with implementations for the most common transformer families: dense transformers, mixture-of-experts variants, and sliding-window attention models. Most publicly released checkpoints are structurally close enough to one of those that adopting an existing architecture requires only weight renaming or a few config field overrides. Some models require new layers that differ minimally from an existing one, such as a different activation function or an added normalization; start those from a copy of the closest layer in max.nn rather than writing from scratch. When your model's compute graph differs structurally from any of the available implementations, you write a new one.

This guide starts with the directory structure every custom architecture shares, then covers the two decisions that shape your implementation: which reference architecture to start from, and how far your model's compute graph departs from it.

Architecture directory structure

A custom architecture lives in a Python package directory with this layout:

{your_model}/
├── layers/
├── __init__.py
├── arch.py
├── model_config.py
├── model.py
├── {model_name}.py
└── weight_adapters.py
  • layers/: holds custom layer implementations for components that differ fundamentally from anything in max.nn.
  • __init__.py: exposes the ARCHITECTURES list that MAX reads to register each entry.
  • arch.py: registers the architecture with MAX, declaring its name (which must match the architectures[0] value in the model's Hugging Face config.json), supported encodings, and the classes that implement the config, model, and weight adapter.
  • model_config.py: translates Hugging Face config.json fields into a typed configuration object MAX uses to build the compute graph.
  • model.py: the pipeline model class registered in arch.py. When building on a reference architecture, this file subclasses the reference's pipeline model and overrides the attributes that differ.
  • {model_name}.py: defines the transformer compute graph. Most models with unique sub-layers, an encoder, or a structure that differs from the generic transformer require this file.
  • weight_adapters.py: translates Hugging Face checkpoint weight names to MAX weight slot names.

Pass the parent directory and package name to --custom-architectures when you run max generate or max serve. MAX imports the package at startup, reads its ARCHITECTURES list, and registers each entry before the model loads.

Choose a reference architecture

The reference architectures in MAX's codebase are complete implementations of this package structure. Each one is a working example that serves production traffic with max serve today. Browse the architectures/ directory alongside this guide to see how the files fit together in a finished implementation.

Most large language models are variations on a small set of transformer blueprints. The attention mechanism, normalization placement, and residual structure in your model are almost certainly identical to an existing reference. Starting from the closest one means you inherit its compute graph, KV cache logic, and hardware optimization for everything you share, rather than reimplementing it.

Your implementation files define only what differs. model_config.py bridges the Hugging Face config fields to MAX's representation. weight_adapters.py remaps checkpoint keys to match MAX's naming. model.py overrides the compute graph where it diverges, and arch.py registers the quantization encodings your checkpoint supports. If your model's structure is identical to a reference and only those surface details differ, those files are all you need.

When the differences go deeper than naming and configuration, the next section guides you to the right implementation depth. Choose a reference based on structural features, not model name or family:

  • llama3/: Standard dense transformer with GQA, RoPE, pre-RMSNorm, and SwiGLU. The right starting point for most models.
  • mistral/: Use when your model uses sliding window (local) attention alternating with global attention layers.
  • gemma3/: Use when your model uses a 5:1 local-to-global attention ratio with RMSNorm placed both before and after the attention module.
  • deepseekV2/: Use when your model combines Mixture-of-Experts (MoE) routing with Multi-Head Latent Attention (MLA), which compresses the KV cache using low-rank projections.
  • deepseekV3/: Use when your model uses MoE routing and MLA with auxiliary-loss-free load balancing.

The sections below build on each other. Config field mapping, weight naming, and KV cache integration are required for every LLM implementation. If your model's structure departs from the reference in bounded, enumerable ways, override the model components that differ. If the differences are structural, extend the compute graph to implement those components from scratch: a different attention mechanism, custom routing, or a positional encoding variant that can't be expressed as an attribute override.

Map config fields

MAX can't use a model's config.json from Hugging Face directly; it needs a typed configuration object to build the compute graph. The general pattern is to subclass the reference config class and override the fields where the Hugging Face convention differs from MAX's representation. The Files and versions tab on a Hugging Face model page includes the config.json your config class needs to map. Use the config to identify which fields your subclass must override and which MAX field names they correspond to:

config.json
{
    "architectures": ["YourModelForCausalLM"],
    "model_type": "your_model",
    "hidden_size": 4096,
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
    "num_hidden_layers": 32,
    "vocab_size": 256000,
    "logit_scale": 0.0625
}
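To make the mapping concrete, here is an illustrative sketch in plain Python, not MAX's actual config classes, of pulling raw config.json fields into a typed object. ToyConfig is a hypothetical stand-in; the field names come from the config.json example above:

```python
import json
from dataclasses import dataclass

# Illustrative only: ToyConfig is a hypothetical stand-in for a MAX config
# class. The field names come from the config.json example above.
@dataclass
class ToyConfig:
    hidden_size: int
    num_attention_heads: int
    num_key_value_heads: int

raw = json.loads(
    '{"hidden_size": 4096, "num_attention_heads": 32, "num_key_value_heads": 8}'
)
cfg = ToyConfig(
    hidden_size=raw["hidden_size"],
    num_attention_heads=raw["num_attention_heads"],
    num_key_value_heads=raw["num_key_value_heads"],
)
print(cfg.num_key_value_heads)  # 8
```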

initialize_from_config() is where you read a raw Hugging Face field and store a derived value. Override this method rather than initialize() when extending a reference config: the super() call runs the parent's full field mapping first, then you add the fields that differ:

model_config.py
class YourConfig(Llama3Config):
    logits_scaling: float = 16.0

    @classmethod
    def initialize_from_config(cls, pipeline_config, huggingface_config):
        config = super().initialize_from_config(pipeline_config, huggingface_config)
        logit_scale = getattr(huggingface_config, "logit_scale", 0.0625)
        config.logits_scaling = 1.0 / logit_scale  # 0.0625 → 16.0
        return config

Map weight names

You can't load Hugging Face weights directly into a MAX model because both frameworks derive weight names as attribute paths through an nn.Module hierarchy, and those hierarchies don't match. To load weights, you need a function that translates Hugging Face checkpoint keys to MAX weight slot names.

The model's .safetensors.index.json file (in the Files and versions tab on Hugging Face) shows the checkpoint's key names and shapes. Use it to identify the prefix and naming patterns your weight adapter must translate:

model.embed_tokens.weight                    [vocab_size, hidden_size]  BF16
model.layers.0.input_layernorm.weight        [hidden_size]              BF16
model.layers.0.self_attn.q_proj.weight       [hidden_size, hidden_size] BF16
model.layers.0.mlp.gate_proj.weight          [ffn_size, hidden_size]    BF16
model.norm.weight                            [hidden_size]              BF16
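If you have the index file locally, a few lines of plain Python (not a MAX API) can surface the key prefixes an adapter must handle. The weight_map content below is a toy excerpt:

```python
import json

# Toy excerpt of a .safetensors.index.json; real files map every weight
# name to the shard file that stores it.
index_json = """
{
  "weight_map": {
    "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
    "model.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
    "model.norm.weight": "model-00002-of-00002.safetensors"
  }
}
"""
weight_map = json.loads(index_json)["weight_map"]
# Collect the top-level prefixes the weight adapter must strip or rename.
prefixes = sorted({name.split(".", 1)[0] for name in weight_map})
print(prefixes)  # ['model']
```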

The most common transformation is stripping a top-level prefix. A rename map and a conversion function registered in arch.py are all you need to handle it:

weight_adapters.py
YOUR_MODEL_SAFETENSOR_MAPPING = {
    "model.": "",  # removes the top-level "model." prefix
}

# pass this function to SupportedArchitecture(weight_adapters=...) in arch.py
def convert_safetensor_state_dict(state_dict, huggingface_config, model_config):
    ...
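The body elided above typically amounts to a key-by-key rename. A minimal sketch of that loop in plain Python (rename_keys is a hypothetical helper, not a MAX API):

```python
# Hypothetical helper, not a MAX API: apply the first matching prefix rule
# to every checkpoint key, leaving the tensor values untouched.
def rename_keys(state_dict, mapping):
    renamed = {}
    for name, weight in state_dict.items():
        for old, new in mapping.items():
            if name.startswith(old):
                name = new + name[len(old):]
                break  # apply at most one rule per key
        renamed[name] = weight
    return renamed

checkpoint = {"model.norm.weight": "tensor-A", "lm_head.weight": "tensor-B"}
print(rename_keys(checkpoint, {"model.": ""}))
# {'norm.weight': 'tensor-A', 'lm_head.weight': 'tensor-B'}
```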

After MAX compiles the compute graph, it walks the model's state_dict() to get a list of named weight slots. It calls your adapter to translate Hugging Face checkpoint keys into slot names, then fills each slot by name. For all weights to load correctly, your adapter must map every checkpoint key to a slot name that exists in the model, and the attribute path your nn.Module hierarchy assigns to each parameter must match what the adapter returns. For quantized models, check the precision of weights and scales and cast to match kernel expectations before returning them from the adapter.
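A quick way to catch both failure modes before serving is a set comparison between the adapter's output keys and the model's slot names. This is an illustrative check in plain Python, not a MAX API; the names below are toy data:

```python
# Illustrative check, not a MAX API: slot names would come from the compiled
# model's state_dict(), adapted keys from your adapter's output.
slot_names = {"embed_tokens.weight", "norm.weight", "layers.0.input_layernorm.weight"}
adapted_keys = {"embed_tokens.weight", "norm.weight", "layers.0.input_layernorm.weight"}

unfilled_slots = slot_names - adapted_keys  # weights that would stay uninitialized
orphan_keys = adapted_keys - slot_names     # checkpoint weights with nowhere to go
print(unfilled_slots, orphan_keys)  # set() set()
```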

MAX supports multiple weight formats through load_weights():

  • .safetensors (Safetensors)
  • .gguf (GGUF)

Write a separate adapter function for each format you need to support, following the same rename-map pattern. For the full API, see max.graph.weights.

Connect to the KV cache

MAX pre-allocates the KV cache at startup from values your config provides, before the first request arrives. Four values drive the allocation: num_key_value_heads, head_dim, num_layers, and model_max_seq_len. Wrong values produce a cache the wrong size: the model runs but generates incoherent output rather than raising an error. The interface that carries these values is ArchConfigWithAttentionKVCache, which your config class must subclass and implement.

The four abstract properties your subclass must define are:

  • num_key_value_heads: int: total KV heads across all devices. Use num_key_value_heads from the Hugging Face config, not num_attention_heads. For grouped-query attention models the two differ; using the wrong field produces a cache sized for the wrong number of heads, silently.
  • head_dim: int: dimensionality of each attention head.
  • num_layers: int: number of hidden layers.
  • model_max_seq_len: int: maximum sequence length the model supports.
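Those four values determine the cache footprint. A back-of-the-envelope sketch with illustrative numbers (plain arithmetic, not a MAX API, assuming bf16 storage and a factor of 2 for keys plus values):

```python
# Illustrative sizing arithmetic, not a MAX API. Assumes bf16 (2 bytes)
# and a factor of 2 for storing both keys and values.
num_key_value_heads = 8
head_dim = 128
num_layers = 32
model_max_seq_len = 8192
dtype_bytes = 2

cache_bytes = (
    2  # keys + values
    * num_layers
    * num_key_value_heads
    * head_dim
    * model_max_seq_len
    * dtype_bytes
)
print(cache_bytes / 2**30)  # 1.0 GiB per full-length sequence
```

Mistakenly passing num_attention_heads (32 here) instead of num_key_value_heads (8) would quadruple the allocation, which is why the GQA distinction above matters.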

MAX calls get_kv_params() on your config to size the cache. The standard implementation uses a static construct_kv_params() helper that calls kv_cache_config.to_params() with those values extracted from the Hugging Face config. The following example shows the pattern all reference architectures follow:

model_config.py
@staticmethod
def construct_kv_params(
    huggingface_config: AutoConfig,
    pipeline_config: PipelineConfig,
    devices: list[DeviceRef],
    kv_cache_config: KVCacheConfig,
    cache_dtype: DType,
) -> KVCacheParams:
    return kv_cache_config.to_params(
        dtype=cache_dtype,
        n_kv_heads=huggingface_config.num_key_value_heads,
        head_dim=YourConfig.get_head_dim(huggingface_config),
        num_layers=YourConfig.get_num_layers(huggingface_config),
        devices=devices,
        data_parallel_degree=pipeline_config.model.data_parallel_degree,
    )

Store the result as a KVCacheParams in a kv_params field on your config class and return it from get_kv_params(). Then pass it to each AttentionWithRope layer at construction time:

model.py
self.attn = AttentionWithRope(
    kv_params=config.kv_params,
    # ... remaining arguments
)

With kv_params correctly configured, MAX dispatches fused attention kernels (RoPE application, KV store, and flash attention) and manages paged memory allocation. See Serve custom model architectures for a complete implementation with KVCacheConfig wired end-to-end.

Override model components

Overriding model components builds on config mapping and weight translation from the previous sections. The addition is a model.py that subclasses the reference model class and overrides only the components that differ. Hugging Face model implementations already follow this pattern: in the transformers library, many model classes subclass a family base and override only what changed. The same approach applies to MAX's nn.Module hierarchy.

Subclass the reference model

model.py imports and subclasses the reference model class. Override only the class attributes or methods that differ from the reference. For example, Qwen2 is structurally identical to Llama 3 except that it uses attention bias on the projection layers. The entire model override is a single class attribute:

model.py
from max.pipelines.architectures.llama3.model import Llama3Model

class Qwen2Model(Llama3Model):
    attention_bias: bool = True

The subclass inherits the full graph-building and compilation logic from the reference architecture. Verify that no weight slot your subclass introduces is absent from the checkpoint, and that no checkpoint key is left without a corresponding slot.

Extend the compute graph

Subclassing overrides attributes or methods within an existing graph structure. Graph extension changes the structure itself: how many normalization layers a block applies, whether attention and MLP share their input, which parameters are learnable. What those changes look like depends entirely on your model's architecture. max.nn provides the building blocks (attention layers, normalization, MLP, embeddings), and most structural differences can be expressed by composing from what's already there. The examples below show two cases where a model's structure differed from the reference and required graph-level changes.

Parallel decoder block: Llama 3 applies attention and MLP sequentially, each to its own residual stream. Some architectures instead apply both to the same normed input and add them together in one step:

model.py
# Sequential
x = x + self.self_attn(self.input_layernorm(x), ...)
x = x + self.mlp(self.post_attention_layernorm(x))

# Parallel decoder
normed = self.input_layernorm(x)
x = x + self.self_attn(normed, ...) + self.mlp(normed)

The parallel form requires one input_layernorm per block (not two) and a single residual add.
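The arithmetic difference is easy to see with scalar stand-ins for the layers (toy functions, none of this is MAX API):

```python
# Toy scalar stand-ins for the real layers, to show how the residual
# streams differ; none of this is MAX API.
def norm(x):
    return x          # identity stand-in for the layernorm

def attn(x):
    return 0.5 * x    # stand-in for self-attention

def mlp(x):
    return x + 1.0    # stand-in for the MLP

x = 2.0

# Sequential: the MLP's residual stream already includes the attention output.
h = x + attn(norm(x))
sequential = h + mlp(norm(h))

# Parallel: attention and MLP both read the same normed input, one residual add.
normed = norm(x)
parallel = x + attn(normed) + mlp(normed)

print(sequential, parallel)  # 7.0 6.0
```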

When building any nn.Module that holds a list of layers, store them in LayerList rather than a plain Python list. LayerList registers each layer in the nn.Module hierarchy under a numeric index (0, 1, …), so weight names resolve to layers.0.*, layers.1.*, and so on. A plain Python list doesn't register its items in the hierarchy: MAX won't find those weights during loading, and the layers run with uninitialized values.

Learnable LayerNorm: Some reference architectures expose norm type through a norm_method parameter, but not all norm types are equivalent. Llama 3's norm_method="layer_norm" creates a ConstantLayerNorm with fixed gamma=1 and beta=0. If your model's checkpoint contains learnable scale weights, this norm has no slots to receive them: the checkpoint weights are silently dropped and the model generates incoherent text. Instantiate the norm directly when you need learnable parameters:

model.py
from max.nn import LayerNorm

self.input_layernorm = LayerNorm(
    dims=config.hidden_size,
    eps=config.rms_norm_eps,
    use_bias=False,
)

LayerNorm(use_bias=False) has a learnable weight attribute that the weight adapter maps from the checkpoint. ConstantLayerNorm has none.

When you reuse subcomponents from a reference architecture (attention layers, normalization, MLP), verify that those subcomponents don't define parameters your model's checkpoint doesn't have. A subcomponent that creates weight slots with no matching checkpoint keys means those parameters stay at their initialized values. Call load_state_dict with strict=True to surface any such mismatch as an error. Check the subcomponent's state_dict() against your checkpoint's keys before reusing it.

Implement custom layers

layers/ holds the individual building blocks that model.py composes into the full architecture. Each file in layers/ implements one component; model.py describes how they connect, what the residual stream looks like, and how they loop over layers.

When your architecture requires multiple custom components (custom attention, MoE routing, a non-standard transformer block), implement each in a separate file in layers/ rather than putting everything in model.py.

Next steps

With your architecture implemented, see Serve custom model architectures to run a generation, package your implementation, and load it at serve time.

If your model requires an operation that max.nn doesn't provide and can't be composed from what's already there, see Custom ops for the interface, registration, and worked examples.
