
Python module

max.nn

Neural network modules for MAX.

Graph-based API:

from max.nn import Linear, AttentionWithRope
from max.nn.kv_cache import KVCacheParams

Eager tensor API:

from max.experimental.nn import Module, Linear, Embedding

Base classes

Identity: Identity layer that passes through input unchanged.
Layer
LayerList: Stores a list of layers.
Module: Base class for model components with weight management.
Sequential: A sequential stack of layers where each layer is called with the output of the previous layer.
Shardable: Protocol for objects that support sharding across multiple devices.
Signals: Signal buffers used for peer-to-peer communication in allreduce.
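The Sequential entry above describes composition by forward chaining. A minimal illustrative sketch of that pattern (not MAX's implementation; the class here is hypothetical):

```python
# Illustrative sketch: a Sequential container calls each layer with the
# output of the previous one. Any callable can serve as a "layer" here.
class Sequential:
    def __init__(self, *layers):
        self.layers = layers

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

model = Sequential(lambda x: x + 1, lambda x: x * 2)
print(model(3))  # (3 + 1) * 2 = 8
```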

Linear layers

ColumnParallelLinear: A Linear layer where the weight and bias are sharded across multiple devices.
Embedding: A lookup table for embedding integer indices into dense vectors.
GPTQLinear: A Linear layer for GPTQ encoding.
Linear: Applies a linear transformation to incoming data: y = xW^T + b.
LinearLoRA: Applies a linear transformation and LoRA to the input.
MLP: Simple multi-layer perceptron composed of three Linear layers.
VocabParallelEmbedding: A lookup table for embedding integer indices into dense vectors, sharded along the vocabulary dimension across devices.
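The transformation Linear computes, y = xW^T + b, can be sketched in plain Python (an illustrative reference for the math only, not MAX's implementation):

```python
# Illustrative sketch: y = x @ W.T + b, with x of shape [in_dim],
# W of shape [out_dim, in_dim], and b of shape [out_dim].
def linear(x, W, b):
    return [sum(w * v for w, v in zip(row, x)) + bias
            for row, bias in zip(W, b)]

x = [1.0, 2.0]
W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # out_dim=3, in_dim=2
b = [0.0, 0.5, -1.0]
print(linear(x, W, b))  # [1.0, 2.5, 2.0]
```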

Normalization

ConstantLayerNorm: Layer normalization block with constant gamma and beta values.
GroupNorm: Group normalization block.
LayerNorm: Layer normalization block.
RMSNorm: Computes the Root Mean Square normalization on inputs.
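The RMSNorm computation can be sketched as follows (an illustrative reference, not MAX's implementation; the gamma weights and epsilon default are assumptions):

```python
import math

# Illustrative sketch: y_i = x_i / rms(x) * gamma_i, where
# rms(x) = sqrt(mean(x_i^2) + eps). Unlike LayerNorm, no mean is
# subtracted and no beta offset is added.
def rms_norm(x, gamma, eps=1e-6):
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * g for v, g in zip(x, gamma)]

out = rms_norm([3.0, 4.0], [1.0, 1.0])  # rms = sqrt(12.5), roughly 3.536
```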

LoRA

AttentionWithRopeAndLoRA: Initializes the LoRA-enabled attention layer.
SupportsLoRA: Base class for supporting LoRA functionality in Modules.

Rotary embeddings

DynamicRotaryEmbedding: Applies RoPE with dynamic scaling for long-context inference.
LinearScalingParams: Scaling parameters for linear RoPE frequency scaling.
Llama3RopeScalingParams: Scaling parameters for Llama3's frequency-based context extension.
Llama3RotaryEmbedding: Applies RoPE with Llama3-style frequency scaling for extended context lengths.
LongRoPERotaryEmbedding: Applies RoPE with LongRoPE scaling for Phi-3.5 models.
LongRoPEScalingParams: Parameters for LongRoPE scaling as used in Phi-3.5 models.
RotaryEmbedding: Applies Rotary Position Embedding (RoPE) to transformer activations.
YarnRotaryEmbedding: Applies generic YaRN (Yet another RoPE eNhancement) Rotary Position Embedding.
YarnScalingParams: Scaling parameters for YaRN (Yet another RoPE eNhancement) frequency interpolation.
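The core RoPE operation that these variants build on can be sketched in plain Python (an illustrative reference for the math, not MAX's implementation; the pairing convention and base value are the common defaults, assumed here):

```python
import math

# Illustrative sketch: RoPE rotates each channel pair (x[i], x[i+1])
# by angle pos * base**(-i/d), so the rotation encodes the token's
# position while preserving vector norms.
def rope(x, pos, base=10000.0):
    d = len(x)
    out = []
    for i in range(0, d, 2):
        theta = pos * base ** (-i / d)
        c, s = math.cos(theta), math.sin(theta)
        out += [x[i] * c - x[i + 1] * s,
                x[i] * s + x[i + 1] * c]
    return out

# At position 0 every angle is zero, so the input passes through unchanged.
print(rope([1.0, 2.0, 3.0, 4.0], 0))
```

The scaling variants listed above (linear, Llama3, LongRoPE, YaRN) differ in how they remap `theta` to stretch the usable context length.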

Transformer

DistributedTransformer: A transformer model consisting of TransformerBlock layers.
DistributedTransformerBlock: Stack of Attention, FeedForward, and RMSNorm layers.
ReturnHiddenStates
ReturnLogits
Transformer: A transformer model consisting of TransformerBlock layers.
TransformerBlock: Stack of Attention, FeedForward, and RMSNorm layers.

Convolution

Conv1D: A 1D convolution over an input signal composed of several input planes.
Conv2d: A 2D convolution over an input signal composed of several input planes.
Conv3D: A 3D convolution over an input signal composed of several input planes.
ConvTranspose1d: A 1D transposed convolution operator over an input image composed of several input planes.
WeightNormConvTranspose1d: A 1D transposed convolution operator, with weight normalization, over an input image composed of several input planes.
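The basic 1D convolution these layers generalize can be sketched in plain Python (illustrative only, not MAX's implementation; single channel, stride 1, no padding assumed):

```python
# Illustrative sketch: "valid" 1D convolution (cross-correlation, as is
# conventional in deep learning) of a signal with a kernel, stride 1.
def conv1d(signal, kernel):
    k = len(kernel)
    return [sum(signal[i + j] * kernel[j] for j in range(k))
            for i in range(len(signal) - k + 1)]

print(conv1d([1, 2, 3, 4], [1, 1]))  # [3, 5, 7]
```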

Mixture of experts

Allreduce: Layer that performs an allreduce operation with automatic implementation selection.
MoE: Implementation of Mixture of Experts (MoE).
MoEGate: Gate module for MoE.
MoEQuantized: Mixture of Experts with FP8 or NVFP4 quantization.
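The routing step a MoE gate performs can be sketched in plain Python (illustrative only, not MAX's MoEGate implementation; top-k softmax routing with renormalization is one common scheme, assumed here):

```python
import math

# Illustrative sketch: softmax the router logits over the experts, keep
# the k highest-scoring experts, and renormalize their weights to sum
# to 1. Each token's output is then a weighted sum of those experts.
def moe_gate(logits, k):
    m = max(logits)  # subtract max for numerical stability
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    topk = sorted(range(len(probs)), key=lambda i: -probs[i])[:k]
    kept = sum(probs[i] for i in topk)
    return {i: probs[i] / kept for i in topk}

weights = moe_gate([1.0, 3.0, 2.0, 0.0], k=2)  # routes to experts 1 and 2
```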

Sampling

MinPSampler: A min_p sampler.
RejectionSampler: Rejection sampler for speculative decoding verification.
RejectionSamplerWithResiduals: A simple rejection sampler.
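The min_p filtering rule can be sketched in plain Python (illustrative only, not MAX's MinPSampler implementation): tokens whose probability falls below `min_p` times the most likely token's probability are masked out, and the survivors are renormalized before sampling.

```python
# Illustrative sketch of min-p filtering over a probability vector.
def min_p_filter(probs, min_p):
    threshold = min_p * max(probs)
    kept = [p if p >= threshold else 0.0 for p in probs]
    total = sum(kept)
    return [p / total for p in kept]

# With min_p=0.2 and max prob 0.5, the cutoff is 0.1: the 0.05 entry
# is dropped and the rest are renormalized.
print(min_p_filter([0.5, 0.3, 0.15, 0.05], 0.2))
```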

Quantization

QuantConfig: Configures scaled quantization settings for a layer or model section.
InputScaleSpec: Specifies how input activations are scaled for scaled quantization.
ScaleGranularity: Specifies the granularity of the quantization scale factor.
ScaleOrigin: Specifies whether the quantization scale is determined statically or dynamically.
WeightScaleSpec: Specifies how weights are scaled for scaled quantization.

Hooks

PrintHook: Hook that prints/saves layer tensor inputs and outputs.

Functions

build_max_lengths_tensor: Builds a [num_steps, 2] uint32 buffer of per-step maximum lengths.
clamp: Clamps values in x to [min, max].
split_batch: Splits a ragged input batch into data-parallel batches.
split_batch_replicated: Splits a ragged token batch into data-parallel batches.
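The clamp operation's semantics can be sketched in plain Python (illustrative only, not MAX's tensor implementation): each value is constrained to the closed interval [min, max].

```python
# Illustrative sketch: elementwise clamp of x into [lo, hi].
def clamp(x, lo, hi):
    return [min(max(v, lo), hi) for v in x]

print(clamp([-2.0, 0.5, 3.0], 0.0, 1.0))  # [0.0, 0.5, 1.0]
```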

Submodules