Python module
max.nn
Neural network modules for MAX.
Graph-based API:

from max.nn import Linear, AttentionWithRope
from max.nn.kv_cache import KVCacheParams

Eager tensor API:

from max.experimental.nn import Module, Linear, Embedding

Base classes
| Class | Description |
|---|---|
| Identity | Identity layer that passes through input unchanged. |
| Layer | |
| LayerList | Stores a list of layers. |
| Module | Base class for model components with weight management. |
| Sequential | A sequential stack of layers where each layer is called with the outputs of the previous layer. |
| Shardable | Protocol for objects that support sharding across multiple devices. |
| Signals | Signal buffers used for peer-to-peer communication in allreduce. |
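The base classes follow the usual composition pattern: a sequential stack threads each layer's output into the next layer. A minimal framework-independent sketch of that pattern (the class names echo the table above, but this is plain Python, not the MAX implementation):

```python
class SketchIdentity:
    """Passes its input through unchanged."""
    def __call__(self, x):
        return x

class SketchSequential:
    """Calls each layer with the output of the previous one."""
    def __init__(self, *layers):
        self.layers = list(layers)

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

def double(x):
    return x * 2

# Threads the value through: identity, then *2, then *2.
model = SketchSequential(SketchIdentity(), double, double)
result = model(3)
```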
Linear layers
| Class | Description |
|---|---|
| ColumnParallelLinear | A Linear layer where the weight and bias are sharded onto multiple devices. |
| Embedding | A lookup table for embedding integer indices into dense vectors. |
| GPTQLinear | A Linear layer for GPTQ encoding. |
| Linear | Applies a linear transformation to incoming data: y = xW^T + b. |
| LinearLoRA | Applies a linear transformation plus a low-rank LoRA update to the input: y = xW^T + b + xA^TB^T. |
| MLP | Simple multi-layer perceptron composed of three Linear layers. |
| VocabParallelEmbedding | A lookup table for embedding integer indices into dense vectors, with the vocabulary dimension sharded across devices. |
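Linear computes y = xW^T + b, and the LoRA variant adds a low-rank correction on top. A NumPy sketch of both computations (the shapes, rank, and alpha/r scaling convention are illustrative assumptions, not the MAX signatures):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))      # (batch, in_features)
W = rng.standard_normal((16, 8))     # (out_features, in_features)
b = rng.standard_normal(16)

# Linear: y = x @ W.T + b
y = x @ W.T + b

# LoRA adds a rank-r correction: (x @ A.T) @ B.T, scaled by alpha / r.
r, alpha = 2, 4
A = rng.standard_normal((r, 8))      # down-projection
B = np.zeros((16, r))                # up-projection, zero-initialized
y_lora = y + (x @ A.T) @ B.T * (alpha / r)
# With B zero-initialized, the LoRA path starts as a no-op.
```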
Normalization
| Class | Description |
|---|---|
| ConstantLayerNorm | Layer normalization block with constant gamma and beta values. |
| GroupNorm | Group normalization block. |
| LayerNorm | Layer normalization block. |
| RMSNorm | Computes the Root Mean Square normalization on inputs. |
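RMSNorm divides by the root mean square of the features instead of subtracting a mean, making it cheaper than LayerNorm. A NumPy sketch of the computation (the epsilon default is an assumption):

```python
import numpy as np

def rms_norm(x, weight, eps=1e-6):
    # Root mean square over the feature dimension.
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * weight

x = np.array([[1.0, 2.0, 3.0, 4.0]])
out = rms_norm(x, weight=np.ones(4))
# After normalization, the mean square of each row is ~1.
```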
LoRA
| Class | Description |
|---|---|
| AttentionWithRopeAndLoRA | Attention layer with RoPE and LoRA enabled. |
| SupportsLoRA | Base class for supporting LoRA functionality in Modules. |
Rotary embeddings
| Class | Description |
|---|---|
| DynamicRotaryEmbedding | Applies RoPE with dynamic scaling for long-context inference. |
| LinearScalingParams | Scaling parameters for linear RoPE frequency scaling. |
| Llama3RopeScalingParams | Scaling parameters for Llama3's frequency-based context extension. |
| Llama3RotaryEmbedding | Applies RoPE with Llama3-style frequency scaling for extended context lengths. |
| LongRoPERotaryEmbedding | Applies RoPE with LongRoPE scaling for Phi-3.5 models. |
| LongRoPEScalingParams | Parameters for LongRoPE scaling as used in Phi-3.5 models. |
| RotaryEmbedding | Applies Rotary Position Embedding (RoPE) to transformer activations. |
| YarnRotaryEmbedding | Applies generic YaRN (Yet another RoPE eNhancement) Rotary Position Embedding. |
| YarnScalingParams | Scaling parameters for YaRN (Yet another RoPE eNhancement) frequency interpolation. |
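RoPE rotates each pair of feature dimensions by a position-dependent angle, encoding position multiplicatively rather than additively; the scaling variants above (linear, Llama3, YaRN, LongRoPE) all rework the per-pair frequencies. A NumPy sketch of the core rotation (base frequency 10000 follows the common convention and is an assumption here):

```python
import numpy as np

def rope(x, base=10000.0):
    """Apply rotary position embedding to x of shape (seq_len, dim)."""
    seq_len, dim = x.shape
    half = dim // 2
    # Per-pair inverse frequencies; angle = position * frequency.
    inv_freq = base ** (-np.arange(half) * 2.0 / dim)
    angles = np.arange(seq_len)[:, None] * inv_freq[None, :]
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, :half], x[:, half:]
    # Rotate each (x1, x2) pair by its position-dependent angle.
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)

x = np.ones((3, 4))
out = rope(x)
# Rotation preserves norms, and position 0 is left unchanged.
```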
Transformer
| Class | Description |
|---|---|
| DistributedTransformer | Transformer model consisting of TransformerBlock layers. |
| DistributedTransformerBlock | Stack of Attention, FeedForward, and RMSNorm layers. |
| ReturnHiddenStates | |
| ReturnLogits | |
| Transformer | A transformer model consisting of TransformerBlock layers. |
| TransformerBlock | Stack of Attention, FeedForward, and RMSNorm layers. |
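A transformer block stacks attention, feed-forward, and RMSNorm sublayers with residual connections; modern stacks typically apply the norm before each sublayer (pre-norm). A NumPy sketch of that wiring with stand-in sublayers (the pre-norm ordering is an assumption about the implementation, and real blocks use attention and MLP layers, not lambdas):

```python
import numpy as np

def rms_norm(x, eps=1e-6):
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

def block(x, attn, ffn):
    """Pre-norm residual block: x + attn(norm(x)), then x + ffn(norm(x))."""
    x = x + attn(rms_norm(x))
    x = x + ffn(rms_norm(x))
    return x

x = np.ones((2, 4))
# Stand-in sublayers that output zeros, to expose the residual wiring.
out = block(x, attn=lambda h: 0.0 * h, ffn=lambda h: 0.0 * h)
# With zeroed sublayers, the residual path makes the block an identity.
```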
Convolution
| Class | Description |
|---|---|
| Conv1D | A 1D convolution over an input signal composed of several input planes. |
| Conv2d | A 2D convolution over an input signal composed of several input planes. |
| Conv3D | A 3D convolution over an input signal composed of several input planes. |
| ConvTranspose1d | A 1D transposed convolution operator over an input image composed of several input planes. |
| WeightNormConvTranspose1d | A 1D transposed convolution operator over an input image composed of several input planes, with weight normalization. |
Mixture of experts
| Class | Description |
|---|---|
| Allreduce | Layer to perform allreduce operation with automatic implementation selection. |
| MoE | Implementation of Mixture of Experts (MoE). |
| MoEGate | Gate module for MoE. |
| MoEQuantized | Mixture of Experts with FP8 or NVFP4 quantization. |
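The MoE gate scores each token against every expert, routes it to the top-k, and normalizes those scores into mixture weights. A NumPy sketch of top-k softmax gating (k and the renormalize-over-top-k choice are illustrative assumptions):

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def moe_gate(x, gate_w, k=2):
    """Return (expert indices, mixture weights) per token."""
    logits = x @ gate_w                          # (tokens, num_experts)
    topk = np.argsort(logits, axis=-1)[:, -k:]   # top-k expert ids
    topk_logits = np.take_along_axis(logits, topk, axis=-1)
    weights = softmax(topk_logits)               # renormalize over the top-k
    return topk, weights

rng = np.random.default_rng(0)
x = rng.standard_normal((5, 8))
idx, w = moe_gate(x, rng.standard_normal((8, 4)))
# Each token's mixture weights sum to 1 over its selected experts.
```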
Sampling
| Class | Description |
|---|---|
| MinPSampler | A min_p sampler. |
| RejectionSampler | Rejection sampler for speculative decoding verification. |
| RejectionSamplerWithResiduals | A simple rejection sampler. |
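Min-p sampling keeps only the tokens whose probability is at least min_p times the top token's probability, then renormalizes before sampling. A NumPy sketch of the filtering step (the min_p value is illustrative):

```python
import numpy as np

def min_p_filter(probs, min_p=0.1):
    """Zero out tokens below min_p * max(probs), then renormalize."""
    threshold = min_p * probs.max(axis=-1, keepdims=True)
    kept = np.where(probs >= threshold, probs, 0.0)
    return kept / kept.sum(axis=-1, keepdims=True)

probs = np.array([0.5, 0.3, 0.15, 0.04, 0.01])
filtered = min_p_filter(probs, min_p=0.1)
# Tokens with p < 0.05 are dropped; the survivors are renormalized.
```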
Quantization
| Class | Description |
|---|---|
| QuantConfig | Configures scaled quantization settings for a layer or model section. |
| InputScaleSpec | Specifies how input activations are scaled for scaled quantization. |
| ScaleGranularity | Specifies the granularity of the quantization scale factor. |
| ScaleOrigin | Specifies whether the quantization scale is determined statically or dynamically. |
| WeightScaleSpec | Specifies how weights are scaled for scaled quantization. |
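Scaled quantization maps float values into a low-precision format via a scale factor; granularity (per-tensor vs. per-channel) and origin (static vs. dynamic) are the two axes the config classes above name. A NumPy sketch of dynamic per-tensor int8 quantization (the int8 target format is an illustrative choice):

```python
import numpy as np

def quantize_int8(x):
    """Dynamic per-tensor scale: computed from the data at runtime."""
    scale = np.abs(x).max() / 127.0
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    return q.astype(np.float32) * scale

x = np.array([0.5, -1.0, 2.0, -4.0], dtype=np.float32)
q, scale = quantize_int8(x)
x_hat = dequantize(q, scale)
# Round-trip error is bounded by half a quantization step (scale / 2).
```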
Hooks
| Class | Description |
|---|---|
| PrintHook | Hook that prints/saves layer tensor inputs and outputs. |
Functions
| Function | Description |
|---|---|
| build_max_lengths_tensor | Builds a [num_steps, 2] uint32 buffer of per-step maximum lengths. |
| clamp | Clamps values in x to [min, max]. |
| split_batch | Splits a ragged input batch into data parallel batches. |
| split_batch_replicated | Splits a ragged token batch into data parallel batches. |