Model bring-up workflow

Model bring-up is the process of adapting a pretrained model from Hugging Face or PyTorch to run in MAX.

To bring a model to MAX, you must implement a model architecture. A model architecture is a Python package that defines the model's compute graph, maps its configuration, and loads its weights. Once implemented, you can serve the model with hardware-optimized kernels, production batching, KV cache management, and multi-GPU distribution.

This guide shows how to write a custom architecture by choosing a reference implementation, adapting its configuration and weights, and extending the compute graph where your model differs.

Architecture directory structure

The majority of text-generation model architectures follow this standard directory layout:

{your_model}/
├── layers/
├── __init__.py
├── arch.py
├── model_config.py
├── model.py
├── {model_name}.py
└── weight_adapters.py

Each file serves a distinct role:

  • layers/: holds custom layer implementations for components that differ fundamentally from anything in max.nn.
  • __init__.py: exposes the ARCHITECTURES list that MAX reads to register each entry.
  • arch.py: registers the architecture with MAX, declaring its name (which must match the architectures[0] value in the model's Hugging Face config.json), supported encodings, and the classes that implement the config, model, and weight adapter.
  • model_config.py: translates Hugging Face config.json fields into a typed configuration object MAX uses to build the compute graph.
  • model.py: the pipeline model class registered in arch.py. When building on a reference architecture, this file subclasses the reference's pipeline model and overrides the attributes that differ.
  • {model_name}.py: defines the transformer compute graph. Most models with unique sub-layers, an encoder, or a structure that differs from the generic transformer require this file.
  • weight_adapters.py: translates Hugging Face checkpoint weight names to MAX weight slot names.
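The name match that arch.py depends on can be checked before you write any other file. The sketch below is a plain-Python illustration and not part of MAX: registered_name_matches is a hypothetical helper, and the inline string stands in for a downloaded config.json.

```python
import json


def registered_name_matches(config_text: str, registered_name: str) -> bool:
    """Return True if registered_name equals architectures[0] in a config.json string."""
    config = json.loads(config_text)
    architectures = config.get("architectures") or []
    return bool(architectures) and architectures[0] == registered_name


# Inline stand-in for the contents of a Hugging Face config.json.
example_config = '{"architectures": ["YourModelForCausalLM"], "model_type": "your_model"}'

name_ok = registered_name_matches(example_config, "YourModelForCausalLM")
name_mismatch = registered_name_matches(example_config, "YourModel")
```

Here name_ok is True and name_mismatch is False, because the comparison is against the full class name string rather than the model_type.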

The following sections describe how each component works and whether it requires customization.

Choose a reference architecture

The reference architectures in MAX's codebase are complete implementations of this package structure. Each one is a working example that serves production traffic with max serve today. Browse the architectures/ directory alongside this guide to see how the files fit together in a finished implementation.

Most large language models are variations on a small set of transformer blueprints. The attention mechanism, normalization placement, and residual structure in your model are almost certainly identical to an existing reference. Starting from the closest one means you inherit its compute graph, KV cache logic, and hardware optimization for everything you share, rather than reimplementing it.

Your implementation files define only what differs. model_config.py bridges the Hugging Face config fields to MAX's representation. weight_adapters.py remaps checkpoint keys to match MAX's naming. model.py overrides the compute graph where it diverges, and arch.py registers the quantization encodings your checkpoint supports. If your model's structure is identical to a reference and only those surface details differ, those files are all you need.

When the differences go deeper than naming and configuration, the sections that follow guide you to the right implementation depth. Choose a reference based on structural features, not model name or family. For example:

  • llama3/: Standard dense transformer with GQA, RoPE, pre-RMSNorm, and SwiGLU. The right starting point for most models.
  • mistral/: Use when your model uses sliding window (local) attention alternating with global attention layers.
  • gemma3/: Use when your model uses a 5:1 local-to-global attention ratio with RMSNorm placed both before and after the attention module.
  • deepseekV2/: Use when your model combines Mixture-of-Experts (MoE) routing with Multi-Head Latent Attention (MLA), which compresses the KV cache using low-rank projections.
  • deepseekV3/: Use when your model uses MoE routing and MLA with auxiliary-loss-free load balancing.

The sections below build on each other. Config field mapping, weight naming, and KV cache integration are required for every LLM implementation. If your model's structure departs from the reference in bounded, enumerable ways, override the model components that differ. If the differences are structural, extend the compute graph to implement those components from scratch: a different attention mechanism, custom routing, or a positional encoding variant that can't be expressed as an attribute override.

Map config fields

MAX can't use a model's config.json from Hugging Face directly; it needs a typed configuration object to build the compute graph. The general pattern is to subclass the reference config class and override the fields where the Hugging Face convention differs from MAX's representation. The Files and versions tab on a Hugging Face model page includes the config.json your config class needs to map. For a field-by-field guide to reading a config and sorting its fields into the work each implies, see Read a model config. Use the config to identify which fields your subclass must override and which MAX field names they correspond to:

config.json
{
    "architectures": ["YourModelForCausalLM"],
    "model_type": "your_model",
    "hidden_size": 4096,
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
    "num_hidden_layers": 32,
    "vocab_size": 256000,
    "logit_scale": 0.0625
}
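A quick way to find the fields that need an override is to diff the keys in config.json against the fields the reference config already maps. The sketch below is plain Python, not a MAX API. REFERENCE_FIELDS is a hypothetical set chosen for illustration; read the reference config class for the fields it actually handles.

```python
import json

# Hypothetical: fields we assume the reference config already maps.
REFERENCE_FIELDS = {
    "architectures",
    "model_type",
    "hidden_size",
    "num_attention_heads",
    "num_key_value_heads",
    "num_hidden_layers",
    "vocab_size",
}


def unmapped_fields(config_text: str, handled: set[str]) -> list[str]:
    """Return the config.json keys that are not in the handled set, sorted."""
    return sorted(set(json.loads(config_text)) - handled)


config_text = """
{
    "architectures": ["YourModelForCausalLM"],
    "model_type": "your_model",
    "hidden_size": 4096,
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
    "num_hidden_layers": 32,
    "vocab_size": 256000,
    "logit_scale": 0.0625
}
"""

needs_mapping = unmapped_fields(config_text, REFERENCE_FIELDS)
```

Under that assumption, needs_mapping contains only "logit_scale", which is the field the config subclass has to handle itself.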

initialize_from_config() is where you read a raw Hugging Face field and store a derived value. Override this method rather than initialize() when extending a reference config: the super() call runs the parent's full field mapping first, then you add the fields that differ:

model_config.py
class YourConfig(Llama3Config):
    logits_scaling: float = 16.0

    @classmethod
    def initialize_from_config(cls, pipeline_config, huggingface_config):
        config = super().initialize_from_config(pipeline_config, huggingface_config)
        logit_scale = getattr(huggingface_config, "logit_scale", 0.0625)
        config.logits_scaling = 1.0 / logit_scale  # 0.0625 → 16.0
        return config

Map checkpoint weights

When you construct your MAX Module, you declare each weight with a name, shape, dtype, and quantization encoding. Checkpoint weights rarely match your declared weights, so you write a weight adapter function. The function converts the checkpoint into the weights your Module expects.

Renaming is the most common transformation needed in a weight adapter. However, you might also need to cast dtypes, drop weights the Module doesn't use, and transform tensor layouts. To learn how to implement these transformations, see Write a weight adapter.
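As a rough illustration of renaming and dropping, the sketch below treats a checkpoint as a plain dictionary from names to tensors. It is not the MAX weight adapter interface: adapt_state_dict, PREFIX_TO_STRIP, and DROP_SUFFIXES are hypothetical names, and the strings stand in for real tensors.

```python
# Hypothetical rules for illustration; a real adapter derives these from
# the names your Module declares and the keys your checkpoint contains.
PREFIX_TO_STRIP = "model."
DROP_SUFFIXES = ("rotary_emb.inv_freq",)


def adapt_state_dict(checkpoint: dict[str, object]) -> dict[str, object]:
    """Rename checkpoint keys and drop entries the Module has no slot for."""
    adapted = {}
    for name, tensor in checkpoint.items():
        if name.endswith(DROP_SUFFIXES):
            continue  # the Module doesn't use this weight
        adapted[name.removeprefix(PREFIX_TO_STRIP)] = tensor
    return adapted


checkpoint = {
    "model.layers.0.self_attn.q_proj.weight": "tensor-a",
    "model.layers.0.self_attn.rotary_emb.inv_freq": "tensor-b",
    "lm_head.weight": "tensor-c",
}
adapted = adapt_state_dict(checkpoint)
```

The result keeps layers.0.self_attn.q_proj.weight and lm_head.weight and drops the inv_freq entry. Dtype casts and layout changes would slot into the same loop.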

Connect to the KV cache

MAX pre-allocates the KV cache at startup from values your config provides, before the first request arrives. Four values drive the allocation: num_key_value_heads, head_dim, num_layers, and model_max_seq_len. Wrong values produce a cache of the wrong size: the model runs but generates incoherent output rather than raising an error. The interface that carries these values is ArchConfigWithAttentionKVCache, which your config class must subclass and implement.

The four abstract properties your subclass must define are:

  • num_key_value_heads: int: total KV heads across all devices. Use num_key_value_heads from the Hugging Face config, not num_attention_heads. For grouped-query attention models the two differ, and using the wrong field silently produces a cache sized for the wrong number of heads.
  • head_dim: int: dimensionality of each attention head.
  • num_layers: int: number of hidden layers.
  • model_max_seq_len: int: maximum sequence length the model supports.
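To see how much these four values matter, the back-of-the-envelope sketch below estimates the cache footprint of one full-length sequence. It is generic arithmetic, not MAX's paged allocator. The sequence length of 4096, the 2-byte element size, and deriving head_dim as hidden_size // num_attention_heads are assumptions for illustration.

```python
def kv_cache_bytes(
    num_layers: int,
    num_kv_heads: int,
    head_dim: int,
    seq_len: int,
    bytes_per_element: int,
) -> int:
    """Estimate the bytes needed to cache keys and values for one sequence."""
    # Factor of 2: each layer stores one key tensor and one value tensor.
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_element


# Values from the example config.json: hidden_size=4096, 32 attention heads,
# 8 KV heads, 32 layers. head_dim is derived here as an assumed convention.
head_dim = 4096 // 32

correct = kv_cache_bytes(32, 8, head_dim, seq_len=4096, bytes_per_element=2)
wrong = kv_cache_bytes(32, 32, head_dim, seq_len=4096, bytes_per_element=2)
```

With these numbers, correct is 512 MiB and wrong, which uses num_attention_heads in place of num_key_value_heads, is four times larger.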

MAX calls get_kv_params() on your config to size the cache. The standard implementation uses a static construct_kv_params() helper that calls kv_cache_config.to_params() with those values extracted from the Hugging Face config. The following example shows the pattern all reference architectures follow:

model_config.py
@staticmethod
def construct_kv_params(
    huggingface_config: AutoConfig,
    pipeline_config: PipelineConfig,
    devices: list[DeviceRef],
    kv_cache_config: KVCacheConfig,
    cache_dtype: DType,
) -> KVCacheParams:
    return kv_cache_config.to_params(
        dtype=cache_dtype,
        n_kv_heads=huggingface_config.num_key_value_heads,
        head_dim=YourConfig.get_head_dim(huggingface_config),
        num_layers=YourConfig.get_num_layers(huggingface_config),
        devices=devices,
        data_parallel_degree=pipeline_config.model.data_parallel_degree,
    )

Store the result as a KVCacheParams in a kv_params field on your config class and return it from get_kv_params(). Then pass it to each AttentionWithRope layer at construction time:

model.py
self.attn = AttentionWithRope(
    kv_params=config.kv_params,
    # ... remaining arguments
)

Most architectures return a flat KVCacheParams. Architectures with multiple or hybrid KV caches (for example, separate sliding-window and full-attention caches) return a MultiKVCacheParams tree instead, which the cache manager consumes through the shared KVCacheParamInterface.

With kv_params correctly configured, MAX dispatches fused attention kernels (RoPE application, KV store, and flash attention) and manages paged memory allocation. See Serve custom model architectures for a complete implementation with KVCacheConfig wired end-to-end.

Override model components

Overriding model components builds on config mapping and weight translation from the previous sections. The addition is a model.py that subclasses the reference model class and overrides only the components that differ. Hugging Face model implementations already follow this pattern: in the transformers library, many model classes subclass a family base and override only what changed. The same approach applies to MAX's nn.Module hierarchy.

Subclass the reference model

model.py imports and subclasses the reference model class. Override only the class attributes or methods that differ from the reference. For example, Qwen2 is structurally identical to Llama 3 except that it uses attention bias on the projection layers. The entire model override is a single class attribute:

model.py
from max.pipelines.architectures.llama3.model import Llama3Model

class Qwen2Model(Llama3Model):
    attention_bias: bool = True

The subclass inherits the full graph-building and compilation logic from the reference architecture. Verify that no weight slot your subclass introduces is absent from the checkpoint, and that no checkpoint key is left without a corresponding slot.

Extend the compute graph

Subclassing overrides attributes or methods within an existing graph structure. Graph extension changes the structure itself: how many normalization layers a block applies, whether attention and MLP share their input, which parameters are learnable. What those changes look like depends entirely on your model's architecture. max.nn provides the building blocks (attention layers, normalization, MLP, embeddings), and most structural differences can be expressed by composing from what's already there. The examples below show two cases where a model's structure differed from the reference and required graph-level changes.

Parallel decoder block: Llama 3 applies attention and MLP sequentially, each with its own normalization and its own residual add. Some architectures instead apply both to the same normed input and add them together in one step:

model.py
# Sequential
x = x + self.self_attn(self.input_layernorm(x), ...)
x = x + self.mlp(self.post_attention_layernorm(x))

# Parallel decoder
normed = self.input_layernorm(x)
x = x + self.self_attn(normed, ...) + self.mlp(normed)

The parallel form requires one input_layernorm per block (not two) and a single residual add.

When building any nn.Module that holds a list of layers, store them in LayerList rather than a plain Python list. LayerList registers each layer in the nn.Module hierarchy under a numeric index (0, 1, …), so weight names resolve to layers.0.*, layers.1.*, and so on. A plain Python list doesn't register its items in the hierarchy: MAX won't find those weights during loading, and the layers run with uninitialized values.
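The toy classes below mimic the registration behavior described above so you can see why a plain list hides its weights. They are stand-ins written for this illustration and are not MAX's nn.Module or LayerList implementations. Module, LayerList, Linear, and Weight here are all hypothetical.

```python
class Weight:
    """Placeholder for a real tensor."""


class Module:
    """Toy module that records child modules and weights by attribute name."""

    def __init__(self):
        object.__setattr__(self, "_children", {})
        object.__setattr__(self, "_weights", {})

    def __setattr__(self, name, value):
        # Only Module and Weight values are registered; a list is not.
        if isinstance(value, Module):
            self._children[name] = value
        elif isinstance(value, Weight):
            self._weights[name] = value
        object.__setattr__(self, name, value)

    def weight_names(self, prefix=""):
        names = [prefix + name for name in self._weights]
        for child_name, child in self._children.items():
            names += child.weight_names(prefix + child_name + ".")
        return names


class LayerList(Module):
    """Toy container that registers each layer under its numeric index."""

    def __init__(self, layers):
        super().__init__()
        for index, layer in enumerate(layers):
            setattr(self, str(index), layer)


class Linear(Module):
    def __init__(self):
        super().__init__()
        self.weight = Weight()


class WithPlainList(Module):
    def __init__(self):
        super().__init__()
        self.layers = [Linear(), Linear()]  # not registered


class WithLayerList(Module):
    def __init__(self):
        super().__init__()
        self.layers = LayerList([Linear(), Linear()])  # registered


plain_names = WithPlainList().weight_names()
registered_names = WithLayerList().weight_names()
```

In this toy, plain_names is empty while registered_names is ["layers.0.weight", "layers.1.weight"], which is the naming the weight adapter has to target.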

Learnable LayerNorm: Some reference architectures expose norm type through a norm_method parameter, but not all norm types are equivalent. Llama 3's norm_method="layer_norm" creates a ConstantLayerNorm with fixed gamma=1 and beta=0. If your model has learnable scale weights in the checkpoint, passing this parameter leaves them with no slot to load into: they are dropped silently and the model generates incoherent text. Instantiate the norm directly when you need learnable parameters:

model.py
from max.nn import LayerNorm

self.input_layernorm = LayerNorm(
    dims=config.hidden_size,
    eps=config.rms_norm_eps,
    use_bias=False,
)

LayerNorm(use_bias=False) has a learnable weight attribute that the weight adapter maps from the checkpoint. ConstantLayerNorm has none.

When you reuse subcomponents from a reference architecture (attention layers, normalization, MLP), verify that those subcomponents don't define parameters your model's checkpoint doesn't have. A subcomponent that creates weight slots with no matching checkpoint keys means those parameters stay at their initialized values. Call load_state_dict with strict=True to surface any such mismatch as an error. Check the subcomponent's state_dict() against your checkpoint's keys before reusing it.
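The key comparison described above comes down to two set differences. The sketch below is plain Python and assumes you have already collected the Module's slot names and the checkpoint keys after the weight adapter has run. compare_weight_names and the example names are hypothetical.

```python
def compare_weight_names(
    module_slots: list[str], checkpoint_keys: list[str]
) -> dict[str, list[str]]:
    """Report slots with no checkpoint key and checkpoint keys with no slot."""
    slots, keys = set(module_slots), set(checkpoint_keys)
    return {
        # These slots would stay at their initialized values.
        "missing_from_checkpoint": sorted(slots - keys),
        # These checkpoint weights would never be loaded.
        "unused_checkpoint_keys": sorted(keys - slots),
    }


module_slots = [
    "layers.0.attn.q_proj.weight",
    "layers.0.attn.q_proj.bias",
    "norm.weight",
]
checkpoint_keys = [
    "layers.0.attn.q_proj.weight",
    "norm.weight",
    "norm.bias",
]

report = compare_weight_names(module_slots, checkpoint_keys)
```

An empty list in both entries is the result to aim for. In this example the report flags layers.0.attn.q_proj.bias as a slot without a checkpoint key and norm.bias as a checkpoint key without a slot.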

Implement custom layers

layers/ holds the individual building blocks that model.py composes into the full architecture. Each file in layers/ implements one component; model.py describes how they connect, what the residual stream looks like, and how they loop over layers.

When your architecture requires multiple custom components (custom attention, MoE routing, a non-standard transformer block), implement each in a separate file in layers/ rather than putting everything in model.py.

Expose the architecture

The arch.py file ties the rest of the package together. It should include one or more SupportedArchitecture instances that point MAX at your config class, model class, weight adapters, and more:

arch.py
qwen2_arch = SupportedArchitecture(
    name="Qwen2ForCausalLM",
    task=PipelineTask.TEXT_GENERATION,
    pipeline_model=Qwen2Model,
    config=Qwen2Config,
    supported_encodings={"float32", "bfloat16"},
    weight_adapters={
        WeightsFormat.safetensors: weight_adapters.convert_safetensor_state_dict,
    },
    # ... tokenizer, context_type, default_weights_format, ...
)

After you create a SupportedArchitecture instance, expose it through your package's __init__.py so MAX can register it when serving:

__init__.py
from .arch import qwen2_arch

ARCHITECTURES = [qwen2_arch]

Most packages expose a single architecture, but you can list more than one when:

  • There are multiple tasks for the same model. The registry allows entries that share a name as long as their task differs (for example, a text generation architecture and an embedding architecture for the same model).

  • There are multiple Hugging Face class names. A single package that covers a family of related variants might register one entry for each architectures[0] string that appears in the family's config.json files.
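The two cases above can be pictured as a lookup keyed on both name and task. The sketch below is a toy model of that rule inferred from the description, not MAX's registry code. ArchEntry, register_all, and the task strings are hypothetical.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ArchEntry:
    """Hypothetical stand-in for a registered architecture entry."""

    name: str
    task: str


def register_all(entries: list[ArchEntry]) -> dict[tuple[str, str], ArchEntry]:
    """Index entries by (name, task) and reject exact duplicates."""
    registry: dict[tuple[str, str], ArchEntry] = {}
    for entry in entries:
        key = (entry.name, entry.task)
        if key in registry:
            raise ValueError(f"duplicate architecture entry: {key}")
        registry[key] = entry
    return registry


registry = register_all(
    [
        # Same name, different tasks.
        ArchEntry("YourModelForCausalLM", "text_generation"),
        ArchEntry("YourModelForCausalLM", "embeddings"),
        # A second Hugging Face class name from the same family.
        ArchEntry("YourModelV2ForCausalLM", "text_generation"),
    ]
)
```

All three entries coexist in this toy because no two share both a name and a task. Registering the first entry twice would raise ValueError.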

Next steps

With your architecture implemented, see Serve custom model architectures to run a generation, package your implementation, and load it at serve time.

If your model requires an operation that max.nn doesn't provide and can't be composed from what's already there, see Custom ops for the interface, registration, and worked examples.

Once your architecture loads and generates text, verify its forward pass numerically against a PyTorch reference. See Logit comparison for the full workflow, or use the debug-model skill to run that per-layer comparison with an AI coding agent.
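For a first sanity check before the full workflow, comparing two logit vectors for the same token position already tells you a lot. The sketch below is plain Python over lists of floats and is not the Logit comparison tooling. compare_logits and the tolerance of 1e-2 are assumptions chosen for illustration.

```python
def compare_logits(
    reference: list[float], candidate: list[float], atol: float = 1e-2
) -> dict[str, object]:
    """Compare two logit vectors by max absolute difference and top-1 token."""
    if len(reference) != len(candidate):
        raise ValueError("logit vectors must have the same length")
    max_abs_diff = max(abs(r - c) for r, c in zip(reference, candidate))
    same_top1 = reference.index(max(reference)) == candidate.index(max(candidate))
    return {
        "max_abs_diff": max_abs_diff,
        "within_atol": max_abs_diff <= atol,
        "same_top1": same_top1,
    }


# Made-up values standing in for a PyTorch reference and a MAX forward pass.
reference_logits = [1.25, -0.5, 3.0, 0.125]
candidate_logits = [1.25, -0.5, 2.5, 0.125]

result = compare_logits(reference_logits, candidate_logits)
```

In this example the top-1 token agrees but the maximum difference of 0.5 exceeds the tolerance, which is the kind of result that calls for a per-layer comparison.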
