# Modular
> Deploy fast and scalable GenAI inference
This file contains all documentation content in a single document following the llmstxt.org standard.
## Attention mask
An attention mask is a mechanism used in the [attention](attention.mdx) layers
of a [transformer](transformer.mdx) model to indicate which tokens the model
should ignore when computing attention scores.
For example, attention masks can prevent the model from attending to [padding
tokens](padding-tokens.mdx), which are added to make sequences in a batch the
same length and thus offer no information for attention.
Another common mask is a "causal mask" (or "look-ahead mask"), which prevents
the [self-attention](self-attention.mdx) layer from looking at future tokens when
predicting a new token, ensuring that it attends only to previous tokens in the
sequence. Although it may sound absurd that the model would even try to look at
future tokens (because it generates tokens one at a time, in order), the
self-attention mechanism is designed for more general-purpose attention scoring. In its
most basic form, self-attention is agnostic to token order—it looks at all
tokens in the sequence equally, based on their embeddings, and calculates
scores by looking both backward and ahead in the sequence. (For example,
self-attention is used during [context encoding](context-encoding.mdx) to
establish an understanding of the input text.) So instead of creating a
different kind of attention mechanism for autoregressive inference, the causal
mask instructs the self-attention layer to simply ignore all future tokens and
only look backward when generating scores that help predict the next token.
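As a minimal sketch, a causal mask can be built as an additive matrix of `-inf` values above the diagonal and applied to the attention scores before the softmax (a NumPy illustration of the concept only, not any particular framework's API):

```python
import numpy as np

def causal_mask(seq_len: int) -> np.ndarray:
    """Additive causal mask: 0 where attention is allowed (the current
    token and earlier ones), -inf where future tokens must be ignored."""
    return np.triu(np.full((seq_len, seq_len), -np.inf), k=1)

# Applying the mask before softmax drives the weights for all
# future tokens to exactly zero.
scores = np.ones((4, 4))                  # dummy attention scores
masked = scores + causal_mask(4)
weights = np.exp(masked - masked.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
# Row 0 can attend only to token 0, so all its weight lands there.
```

Because `exp(-inf)` is 0, the masked positions contribute nothing to the softmax, which is exactly what "ignore all future tokens" means in practice.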
---
## Attention
A mechanism used in AI models such as [transformers](transformer.mdx) that
enables the model to selectively focus on different parts of the input sequence
when making predictions.
Unlike traditional model architectures that process all input data with equal
importance, models with attention assign different importance levels to
different tokens (such as words or pixels). This allows the model to better
understand the complete meaning of the input, especially when an accurate
meaning depends on relationships between tokens that are far apart (such as
between words that occur far apart in a sentence).
Attention is crucial for large language models (LLMs) so they can capture
long-range dependencies and contextual relationships in the given text. It
allows LLMs to handle complex and nuanced language, enabling them to generate
coherent and contextually relevant output even when the input text includes
nuanced references to other parts of the text.
Attention was introduced and refined in the papers [Neural Machine Translation
by Jointly Learning to Align and Translate](https://arxiv.org/abs/1409.0473)
(Bahdanau et al., 2014) and [Effective Approaches to Attention-based Neural
Machine Translation](https://arxiv.org/abs/1508.04025) (Luong et al., 2015).
The most well-known form of attention is [self-attention](self-attention.mdx),
in which each token gets its own attention score for every other token (each
token "attends to" all other tokens) to determine the relative importance of
every other token in that context.
## Implementation details
The classic attention operation consists of the following operations (`bmm` is
short for batched matrix multiplication):
* `bmm`: `Q x Transpose(K)`
where `Q`, `K` both have shape `[batchSize, numHeads, S, d]`
and `Q x K^t` has the shape `[batchSize, numHeads, S, S]`
* `softmax`
* `bmm`: `softmax(Q x K^t) x V`
where V has the shape `[batchSize, numHeads, S, d]`
`S` denotes the sequence length. Depending on the model, it can be as large as
`O(10^3)` to `O(10^4)`. `d` is the size per head in multi-head attention. It's
usually a power of 2, such as 64 or 128, and smaller than `S`.
A limitation of the classic implementation is that it materializes an
intermediate matrix of shape `[batchSize, numHeads, S, S]`. This introduces
`O(S^2)` memory allocation and traffic.
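The operations above can be sketched in NumPy (with the customary `1/sqrt(d)` scaling added before the softmax). Note how the intermediate `weights` tensor has shape `[batchSize, numHeads, S, S]`, which is the `O(S^2)` allocation in question:

```python
import numpy as np

def classic_attention(Q, K, V):
    """Naive batched attention: materializes the full [B, H, S, S]
    score matrix, which is where the O(S^2) memory cost comes from."""
    d = Q.shape[-1]
    scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d)     # [B, H, S, S]
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)        # softmax
    return weights @ V                                    # [B, H, S, d]

B, H, S, d = 2, 4, 8, 16
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((B, H, S, d)) for _ in range(3))
out = classic_attention(Q, K, V)                          # [B, H, S, d]
```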
---
## Autoregression
Autoregression is a process by which an AI model iteratively predicts future
values based on previous values in a sequence, using its own output as input to
itself. Because each prediction depends on prior context, the process is
sequential, which limits parallelization.
Autoregression is a standard procedure in [transformer](transformer.mdx) models
such as large language models (LLMs) and other models that perform time-series
forecasting. This autoregressive process explains why AI chatbots like ChatGPT
and Gemini stream their output one word at a time—although they sometimes run so
fast that they appear to produce more than one word at a time.
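The loop itself can be sketched in a few lines (the `toy_next_token` function is a hypothetical stand-in for a model's forward pass; a real LLM would return the most likely next token from its vocabulary):

```python
def toy_next_token(tokens):
    """Hypothetical stand-in for a model forward pass: 'predicts' the
    next value in a numeric sequence."""
    return tokens[-1] + 1

def generate(prompt, num_new_tokens):
    """Autoregressive generation: each prediction is appended to the
    sequence and fed back as input for the next prediction, so the
    steps must run sequentially."""
    tokens = list(prompt)
    for _ in range(num_new_tokens):
        tokens.append(toy_next_token(tokens))   # output becomes input
    return tokens

generate([1, 2, 3], num_new_tokens=3)           # -> [1, 2, 3, 4, 5, 6]
```

The data dependency is visible in the loop body: step *n* cannot begin until step *n - 1* has produced its token, which is why this phase resists parallelization.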
---
## Batching
Batching is the process of combining multiple inference requests into a single
forward pass through the model, thus executing multiple requests simultaneously
and improving computational efficiency. To account for requests with varying
sequence lengths, it's common to add techniques such as
[padding](padding-tokens.mdx) (to standardize lengths) or [ragged
tensors](ragged-tensors.mdx) (to handle variable lengths directly).
Batch sizes can be either static or dynamic. Whereas static batching uses a
fixed batch size and thus waits until the system receives a specific number of
inference requests before sending them into the model, dynamic batching uses a
flexible batch size. For example, dynamic batching may send a batch of requests
to the model as soon as the batch either reaches a certain number of requests
(the batch size limit) or hits a timeout threshold.
Dynamic batching can get a lot more complicated than that with additional
tricks that keep GPUs busy instead of waiting for one batch to finish before
starting another. One such strategy for large language models (LLMs) is
[continuous batching](continuous-batching.mdx).
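The "full batch or timeout, whichever comes first" dispatch rule can be sketched as follows (timestamped pairs stand in for a live request queue; a real server would interleave this with model execution):

```python
def batch_requests(arrivals, max_batch_size, timeout_s):
    """Group timestamped requests into batches: dispatch when a batch
    is full (batch size limit) or when the oldest request in it has
    waited past the timeout."""
    batches, current, start = [], [], None
    for t, req in arrivals:
        # Timeout check: flush the pending batch before accepting a
        # request that arrives after the wait window has expired.
        if current and t - start >= timeout_s:
            batches.append(current)
            current, start = [], None
        if start is None:
            start = t
        current.append(req)
        if len(current) >= max_batch_size:      # batch size limit reached
            batches.append(current)
            current, start = [], None
    if current:                                 # flush any remainder
        batches.append(current)
    return batches

batch_requests([(0.0, "a"), (0.1, "b"), (5.0, "c")],
               max_batch_size=4, timeout_s=1.0)  # -> [["a", "b"], ["c"]]
```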
---
## Context encoding
Context encoding (also known as "prefill") is the first phase in a [transformer
model](transformer.mdx) that converts input data into a cached numerical
representation ([KV cache](kv-cache.mdx)) and predicts the first token. It
occurs after the input has already been [tokenized](tokenization.mdx)
(preprocessed).
Context encoding is then followed by the [autoregressive](autoregression.mdx)
token generation phase, which produces one token at a time. If it weren't for
the KV cache built during context encoding, the model would have to recalculate
the [self-attention](self-attention.mdx) score for each token in the original
input, every time it starts to predict a new token.
Context encoding is usually the most computationally expensive phase in an LLM,
because it must calculate attention scores for every token in the input
sequence. Although this process may be parallelized across thousands of GPU
threads (because each token can be processed separately), it is still a
significant latency factor for time-to-first-token (TTFT). The model can
usually produce subsequent tokens much faster than the first one because each
round of token generation needs to calculate an attention score for only one
token (the new one).
---
## Continuous batching
Continuous batching is a [batching](batching.mdx) technique that can
continuously dispatch inference requests to the GPU for [token
generation](token-generation.mdx) and dramatically improve GPU utilization.
Continuous batching can start executing a new batch even before the previous
batch finishes its pass through the model, because this batching technique
schedules new processing at the "token level."
That is, because large language models (LLMs) generate responses one token at a
time, there is a repeated cycle during inference (the token generation phase)
in which a new batch can jump in to utilize the GPU, even before a previous
batch finishes its pass through the model. That's what it means to operate at
the "token level"—the batch scheduler focuses on keeping the GPU busy with
token generation at all times, instead of waiting for the previous batch to
finish its complete forward pass.
This is sometimes called "in-flight batching" in cases where context
encoding and token generation requests are combined into the same batch.
---
## Embedding
An embedding (also known as a "vector embedding") is a numerical representation
of information in a high-dimensional vector space. For example, a token
embedding (or word embedding) encodes the meaning of words for use in large
language models (LLMs).
Because artificial neural networks (AI models) are a sequence of mathematical
operations, they require numerical structures as input. Vector embeddings are
numerical structures that provide a way to express a wide range of complex
concepts. They can be used to capture information about all sorts of things,
including words, groups of words, sounds, images, and more.
For example, [tokenizing](tokenization.mdx) a word like "bank" into a simple
number can't encode the different meanings in "bank loan" and "river bank." By
converting the token into a high-dimensional vector, we can encode (or "embed")
a variety of word meanings that help the model understand word relationships
using a notion of closeness along various vector dimensions (expressed through
[Euclidean distance](https://en.wikipedia.org/wiki/Euclidean_distance)). In
this way, when a model encounters the embedding for the word "bank," it can
recognize the relationship it has with nearby words such as "loan" or "river,"
based on the closeness they each have to each other on different vector
dimensions (perhaps a "finance" dimension vs a "geography" dimension that were
learned during training).
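A toy illustration of that idea, with invented two-dimensional embeddings (real embeddings have hundreds or thousands of dimensions whose meanings are learned during training, not hand-assigned):

```python
import math

# Hypothetical 2-D embeddings along (finance, geography) dimensions,
# invented purely for illustration.
embeddings = {
    "bank":  [0.7, 0.6],
    "loan":  [0.9, 0.1],
    "river": [0.1, 0.9],
    "piano": [0.0, 0.0],
}

def euclidean(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

# "bank" sits close to both "loan" (finance) and "river" (geography),
# but far from the unrelated "piano".
nearest = min((w for w in embeddings if w != "bank"),
              key=lambda w: euclidean(embeddings["bank"], embeddings[w]))
```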
Although word embeddings are a type of static embedding that encode the meaning
of individual words as input to an LLM, an LLM also builds its own embeddings
that are hidden inside the model. For example, as an LLM tries to understand
the relationship between each word from an input sequence, it compresses more
information into each token embedding based on the attention scores computed in
the [self-attention layer](self-attention.mdx).
:::note Embedding models
Whereas the token embeddings described above use a vector space to represent
the meaning of individual tokens, the output from an embedding model uses a
vector space to represent the meaning of the input data (a document) as a
whole. In this way, an embedding model allows you to programmatically search
and compare different documents by analyzing their corresponding embeddings,
which can reveal nuanced meaning and semantics far beyond what a pure text
comparison can achieve.
:::
---
## Flash attention
Flash attention is an optimization technique to compute attention blocks in
[transformer](transformer.mdx) models. Traditional [attention](attention.mdx)
requires storing large intermediate activation tensors, leading to high memory
overhead that slows execution because it requires frequent memory transfers
between high-bandwidth memory (HBM) and faster SRAM on the GPU.
Flash attention improves performance and reduces the memory footprint for
attention layers. It reorders computations with techniques such as tiling to
compute attention scores in blocks, and it keeps only small chunks of
activations in the faster on-chip SRAM. This allows the model to process much
longer sequences without running into memory limitations.
By improving the efficiency of attention layers, flash attention enables LLMs
to handle much longer contexts, improving their ability to understand and
generate complex text. It's particularly beneficial for:
* Large language models with long context windows
* Vision transformers processing high-resolution images
* Multi-modal models with large attention matrices
* Fine-tuning large models on limited GPU memory
## Implementation details
Flash attention optimizes the classic [attention](attention.mdx) mechanism by:
1. **Tiling the computation**: Breaking the `Q`, `K`, and `V` matrices into
smaller blocks that fit in GPU shared memory, which is much faster than
global memory.
2. **Fusing operations**: Combining softmax normalization with matrix
multiplication for each tile into a single kernel.
These help maximize the locality and reduce DRAM (global memory) traffic.
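The core trick can be sketched in NumPy: process the keys and values tile by tile, maintaining a running maximum and running softmax denominator (an "online softmax") so the full `[S, S]` score matrix never exists. This is a single-head sketch of the algorithm's structure only, not the fused GPU kernel:

```python
import numpy as np

def tiled_attention(Q, K, V, block=4):
    """Single-head attention computed block by block with a running
    ("online") softmax, so the full [S, S] score matrix is never
    materialized -- only [block, block] tiles."""
    S, d = Q.shape
    out = np.empty_like(Q)
    inv_sqrt_d = 1.0 / np.sqrt(d)
    for i in range(0, S, block):                  # loop over query tiles
        q = Q[i:i + block] * inv_sqrt_d
        m = np.full(q.shape[0], -np.inf)          # running row maximum
        l = np.zeros(q.shape[0])                  # running softmax denominator
        acc = np.zeros((q.shape[0], d))           # running weighted sum of V
        for j in range(0, S, block):              # loop over key/value tiles
            s = q @ K[j:j + block].T              # one [block, block] tile
            m_new = np.maximum(m, s.max(axis=-1))
            p = np.exp(s - m_new[:, None])
            corr = np.exp(m - m_new)              # rescale earlier partials
            l = l * corr + p.sum(axis=-1)
            acc = acc * corr[:, None] + p @ V[j:j + block]
            m = m_new
        out[i:i + block] = acc / l[:, None]
    return out
```

The result is numerically identical to classic attention; only the order of computation changes, trading the `O(S^2)` intermediate for small tiles that fit in fast on-chip memory.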
To see an implementation of
[FlashAttention-2](https://arxiv.org/abs/2307.08691) as a fused operation, see
[`fused_attention.mojo` on
GitHub](https://github.com/modular/modular/blob/main/examples/custom_ops/kernels/fused_attention.mojo).
---
## KV cache
KV (key-value) cache is a memory structure used in
[transformer](transformer.mdx) models to store key-value tensors output from
[self-attention](self-attention.mdx) layers. The KV cache speeds up inference
for transformer models such as large language models (LLMs) by avoiding the
need to recompute the self-attention scores for all previous tokens in a
sequence.
For example, suppose an LLM is trying to complete the sentence, "The quick
brown fox..." After the model predicts "jumps" and then begins to predict the
next token, the model must know the attention score for every token in the
sequence so far (including the one it just predicted). That is, for each step
in the [autoregression](autoregression.mdx) cycle, it must process the entire
sequence thus far:
1. "The quick brown fox..."
2. "The quick brown fox jumps..."
3. "The quick brown fox jumps over..."
And so on.
By storing the already-computed keys and values for previous tokens in the KV
cache, the model simply reads from the cache at each step, instead of
recomputing those tensors all over again. Once the model predicts the next
token, it computes that token's keys and values and adds them to the KV cache.
As the sequence length grows during inference (as more words are generated),
the KV cache becomes the dominant factor in a model's memory usage. The
sequence length is always limited by the model's total context window length,
which varies across models and can usually be configured.
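The bookkeeping amounts to an append-only store, sketched below (the `project` function is a hypothetical stand-in for the model's learned key/value projections):

```python
class KVCache:
    """Minimal sketch: stores the key and value tensors already computed
    for each token, so earlier positions are never reprocessed."""
    def __init__(self):
        self.keys, self.values = [], []

    def append(self, k, v):
        self.keys.append(k)
        self.values.append(v)

    def __len__(self):
        return len(self.keys)

# Hypothetical per-token projection; a real model computes keys and
# values with learned weight matrices.
def project(token):
    return ("k", token), ("v", token)

cache = KVCache()
for token in ["The", "quick", "brown", "fox"]:
    k, v = project(token)     # only the newest token is processed
    cache.append(k, v)        # its keys/values join the cached history
```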
---
## Padding tokens
Padding tokens are extra tokens (usually zeros or special tokens) that are
added to the input for a model so that the input matches the model's fixed
input length or to ensure that all sequences in a [batch](batching.mdx) have
the same length.
In [transformer](transformer.mdx) models, padding tokens have been mostly
replaced with [ragged tensors](ragged-tensors.mdx).
---
## PagedAttention
PagedAttention is a memory management technique designed to improve GPU memory
utilization during large language model (LLM) serving. Inspired by classical
virtual memory and paging methods used in operating systems, PagedAttention
divides the [KV cache](kv-cache.mdx) into fixed-size blocks, which are not
necessarily stored contiguously in memory. This approach enables more efficient
handling of dynamic states in LLMs, allowing the model to manage large context
sizes while optimizing memory usage, as described in the 2023 paper [Efficient
Memory Management for Large Language Model Serving with
PagedAttention](https://arxiv.org/abs/2309.06180) (Kwon et al., 2023).
Also written as "paged attention."
---
## Prefill
Prefill is the first phase of an AI model's forward pass in which the model
processes the input and initializes a cache to accelerate predictions.
Different model architectures may have their own version of a prefill, but it's
primarily associated with large language models (LLMs), in which case it's also
called [context encoding](context-encoding.mdx).
---
## Ragged tensors
Ragged tensors are a method for batching multiple requests with differing
sequence lengths without the need for [padding tokens](padding-tokens.mdx).
Ragged tensors allow sequences of variable lengths to be processed together
efficiently by storing them in a compact, non-uniform format.
Also sometimes referred to as "packed tensors."
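One common ragged layout stores all tokens in a single flat buffer plus per-sequence offsets (a representation chosen here for illustration; frameworks differ in the details):

```python
def pack(sequences):
    """Pack variable-length sequences into one flat buffer plus row
    offsets -- no padding tokens required."""
    values, offsets = [], [0]
    for seq in sequences:
        values.extend(seq)
        offsets.append(len(values))
    return values, offsets

def unpack_row(values, offsets, i):
    """Recover sequence i by slicing between its offsets."""
    return values[offsets[i]:offsets[i + 1]]

values, offsets = pack([[1, 2, 3], [4], [5, 6]])
# values  -> [1, 2, 3, 4, 5, 6]
# offsets -> [0, 3, 4, 6]
```

Compared to padding, no memory or compute is spent on filler tokens; the offsets tell the kernel where each sequence begins and ends.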
---
## Self-attention
Self-attention is a mechanism in a [transformer](transformer.mdx) model that
calculates the importance of different tokens (such as words) in a sequence,
relative to each other. Each token is said to "attend to" all other tokens in
the sequence by assigning an "attention score" to each one.
In a large language model (LLM), self-attention allows the model to build an
understanding of the whole text by evaluating how each word is relevant to all
other words in the text, no matter how far they are from each other.
The attention scores are computed using query, key, and value (QKV) vectors
that pertain to each token:
- The **query** is a vector that expresses what information a token is
*looking for* among all the other tokens (like a search query).
- The **key** is a vector that describes the information a token *offers* to
other tokens (like an answer to a token's query).
- The **value** is a vector that provides the **contextually relevant
  information** about this token.
After calculating attention scores by comparing the **query** and **key**
vectors between tokens, self-attention uses the scores to apply weighted
information from each token's **value** into a new [embedding](embedding.mdx)
for each token. Thus, self-attention outputs a new token embedding for each
token that carries information about its relationship with the other tokens in
the sequence.
The model also saves the calculated keys and values into the [KV
cache](kv-cache.mdx) to avoid redundant recompute for the same tokens during
the next [autoregression](autoregression.mdx) cycle.
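The steps above can be sketched for a single head in NumPy (random matrices stand in for a model's learned projection weights; real models use multiple heads, and the customary `1/sqrt(d)` scaling is included here):

```python
import numpy as np

def self_attention(X, Wq, Wk, Wv):
    """One self-attention head: each token's embedding in X is projected
    into a query, key, and value; queries are compared against keys; and
    the resulting weights mix the values into new token embeddings."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    scores = Q @ K.T / np.sqrt(Q.shape[-1])     # query-key comparison
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)          # attention weights per token
    return w @ V                                # one new embedding per token

rng = np.random.default_rng(0)
X = rng.standard_normal((5, 8))                 # 5 tokens, 8-dim embeddings
Wq, Wk, Wv = (rng.standard_normal((8, 8)) for _ in range(3))
Y = self_attention(X, Wq, Wk, Wv)               # 5 updated token embeddings
```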
---
## Tokenization
Tokenization is the process of dividing the input for an AI model into discrete
units that have numerical IDs called tokens. Depending on what the input is
(such as text, audio, or an image) the tokens might be based on different words
or subwords in text, or different slices/blocks of pixels in images.
For example, consider the sentence, "The cat sat on the mat." A word-level
tokenization might split this sentence into the following words: "The," "cat,"
"sat," "on," "the," "mat." Then it replaces each word with a token (a number).
The token "vocabulary"—the mapping of words to numbers—is predetermined and may
vary from model to model.
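A word-level tokenizer along those lines, with a tiny hypothetical vocabulary (real vocabularies are learned and far larger):

```python
# Hypothetical toy vocabulary mapping words to token IDs.
vocab = {"the": 0, "cat": 1, "sat": 2, "on": 3, "mat": 4}

def word_tokenize(text):
    """Word-level tokenization: split on whitespace, strip punctuation,
    lowercase, then map each word to its numeric token ID."""
    words = [w.strip(".,!?").lower() for w in text.split()]
    return [vocab[w] for w in words]

word_tokenize("The cat sat on the mat.")  # -> [0, 1, 2, 3, 0, 4]
```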
But tokenizers in large language models (LLMs) are much more sophisticated than
that. Among other things, they also tokenize punctuation (or combinations of
words and punctuation) and break words into subwords, which allows them to
tokenize words they've never seen before.
Because LLMs are trained on these tokens, they don't actually understand words
and letters the way we do. They can only recognize and generate information
based on the token vocabulary that they were trained upon. (Popular LLMs have a
token vocabulary with over 100,000 tokens.)
---
## Transformer
A transformer is a neural network architecture designed to perform complex
tasks with sequential data (such as text, speech, and images) in a manner that
can be efficiently parallelized on GPUs or other accelerator hardware. This
makes them highly effective for natural language processing and other
generative AI (GenAI) applications.
The transformer model architecture was first introduced in the paper [Attention
Is All You Need](https://arxiv.org/abs/1706.03762) (Vaswani et al., 2017).
This design emphasizes the use of [self-attention](self-attention.mdx)
mechanisms instead of recurrent structures like recurrent neural networks (RNNs) or
long short-term memory networks (LSTMs), which is what
allows for the processing to be parallelized across separate compute cores
instead of requiring the model to generate predictions synchronously. This
design is currently the foundation for all major large language models (LLMs)
such as GPT, Llama, Gemini, DeepSeek, and more.
---
## Block index
In GPU programming, a block index uniquely identifies a subset of
[threads](thread.mdx) that execute a [kernel](kernel.mdx) function on the GPU.
Threads are grouped into units called [blocks](thread-block.mdx), and multiple
blocks together form a larger structure known as a [grid](grid.mdx).
Each block within the grid is assigned a unique block index, which can be
represented across one, two, or three dimensions. This allows for flexible
organization of threads to match the structure of the problem being solved.
Within each block, individual threads have their own [thread
index](thread-index.mdx), which, together with the block index, determines which
part of the problem each thread should work on. This hierarchical structure of
grids, blocks, and threads enables efficient workload distribution across the
many processing cores of the GPU, maximizing parallel performance.
Because a programmer can arrange thread blocks within a grid across one, two,
or three dimensions, a block index is a 3-element vector of x, y, and z
coordinates. For 2-dimensional arrangements, the z coordinate of all block
indices is 0, and for 1-dimensional arrangements, both the y and z coordinates
of all block indices are 0.
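In Python terms, the arithmetic that combines a block index with a thread index to locate a thread's global position looks like this (a sketch of the index math only; in a real kernel these values come from GPU built-ins rather than function arguments):

```python
def global_coords(block_idx, thread_idx, block_dim):
    """Compute a thread's global (x, y, z) position from its block
    index, thread index, and the block dimensions, using the standard
    block_index * block_dim + thread_index arithmetic per dimension."""
    return tuple(b * d + t
                 for b, t, d in zip(block_idx, thread_idx, block_dim))

# A 1-dimensional arrangement: y and z of every index are 0.
# Block 2 of 128-thread blocks, thread 5 within it -> global x = 261.
global_coords((2, 0, 0), (5, 0, 0), (128, 1, 1))  # -> (261, 0, 0)
```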
---
## Grid
A grid is the top-level organizational structure of the threads executing a
[kernel](kernel.mdx) function on a GPU. A grid consists of multiple [thread
blocks](thread-block.mdx) (also known as *workgroups* on AMD GPUs), which are
further divided into individual [threads](thread.mdx) (or *work units* on AMD
GPUs) that execute the kernel function concurrently.
The division of a grid into thread blocks serves multiple crucial purposes:
- First, it breaks down the overall workload—managed by the grid—into
smaller, more manageable portions that can be processed independently. This
division allows for better resource utilization and scheduling flexibility
across multiple [streaming multiprocessors](streaming-multiprocessor.mdx)
(SMs) in the GPU (or *compute units* on AMD GPUs).
- Second, thread blocks provide a scope for threads to collaborate through
shared memory and synchronization primitives, enabling efficient parallel
algorithms and data sharing patterns.
- Finally, thread blocks help with scalability by allowing the same program to
run efficiently across different GPU architectures, as the hardware can
automatically distribute blocks based on available resources.
The programmer specifies the number of thread blocks in a grid and how they are
arranged across one, two, or three dimensions. Typically, the programmer
determines the dimensions of the grid based on the dimensionality of the data to
process. For example, a programmer might choose a 1-dimensional grid for
processing large vectors, a 2-dimensional grid for processing matrices, and a
3-dimensional grid for processing the frames of a video. Each block within the
grid is assigned a unique [block index](block-index.mdx) that determines its
position within the grid.
Similarly, the programmer also specifies the number of threads per thread block
and how they are arranged across one, two, or three dimensions. Each thread
within a block is assigned a unique [thread index](thread-index.mdx) that
determines its position within the block. The combination of block index and
thread index uniquely identifies the position of a thread within the overall grid.
---
## Kernel
A kernel is a function that runs on a GPU, executing computations in parallel
across a large number of [threads](thread.mdx). Kernels are a fundamental
part of general-purpose GPU (GPGPU) programming and are designed to process
large datasets efficiently by performing the same operation simultaneously on
multiple data elements.
---
## GPU memory
GPU memory consists of both on-chip memory and external dynamic random-access
memory (DRAM), often referred to as *device memory* (in contrast to the *host
memory* used by the CPU).
On-chip memory includes:
- A register file for each [streaming
multiprocessor](streaming-multiprocessor.mdx) (SM), containing the
[registers](register.mdx) used by threads executing on the SM's cores
- An L1 cache for each SM to cache reads from global memory
- Shared memory for each SM, containing data explicitly shared between the
threads of a given [thread block](thread-block.mdx) executing on the SM
- A read-only constant cache for each SM, which caches data read from the
constant memory space in global memory
- An L2 cache shared by all SMs that is used to cache accesses to local or
global memory, including temporary register spills
Device memory includes:
- Global memory, which contains data accessible to all threads
- Constant memory, which contains data explicitly identified as read-only by the
programmer, and which is accessible to all threads
- Local memory, which contains data private to an individual thread, such as
statically allocated arrays, spilled registers, and other elements of the
thread's call stack
Data in global memory persists until explicitly freed, even across
[kernel](kernel.mdx) functions. This means that one kernel can write data to
global memory and then a subsequent kernel can read that data.
---
## Occupancy
In GPU programming, occupancy is a measure of the efficiency of the GPU's
compute resources. It is defined as the ratio of the number of active
[warps](warp.mdx) to the maximum number of warps that can be active on a given
[streaming multiprocessor](streaming-multiprocessor.mdx) (SM) at any one time.
Higher occupancy can improve parallel execution and hide memory latency, but
increasing occupancy does not always boost performance, as factors like memory
bandwidth and instruction dependencies may create bottlenecks. The optimal
occupancy level depends on the workload and GPU architecture.
---
## Register
A GPU register is the fastest form of storage within a [streaming
multiprocessor](streaming-multiprocessor.mdx) (SM). Registers store integer and
floating point values used frequently by a [thread](thread.mdx), reducing
reliance on slower [memory](memory.mdx) types (shared, global, or local
memory).
Registers are located within an SM in what is referred to as a *register file*.
The number of registers depends on the GPU architecture, but modern GPUs support
thousands of registers per SM.
For each thread that it executes, the SM allocates a set of registers for the
private use of that thread. The registers are associated with that thread
throughout its lifetime, even if the thread is not currently executing on the
SM's cores (for example, if it is blocked waiting for data from memory). A
thread can't access registers assigned to a different thread, preventing data
conflicts between threads. If the execution of a [kernel](kernel.mdx) function
by a thread requires more registers than available, the compiler arranges to
spill some register data to the thread's local [memory](memory.mdx). Because
local memory access is slower than register access, programmers should try to
design their kernels to avoid or limit the amount of spill.
---
## Streaming multiprocessor
The basic building block of a GPU is called a *streaming multiprocessor* (SM)
on NVIDIA GPUs or a *compute unit* (CU) on AMD GPUs (they're the same idea and
we'll refer to them both as SM). SMs sit between the high-level GPU control
logic and the individual execution units, acting as self-contained processing
factories that can operate independently and in parallel.
Multiple SMs are arranged on a single GPU chip, with each SM capable of handling
multiple workloads simultaneously. The GPU's global scheduler assigns work to
individual SMs, while the memory controller manages data flow between the SMs
and various [memory](memory.mdx) hierarchies (global memory, L2 cache, etc.).
The number of SMs in a GPU can vary significantly based on the model and
intended use case, from a handful in entry-level GPUs to dozens or even hundreds
in high-end professional cards. This scalable architecture enables GPUs to
maintain excellent performance across different workload sizes and types.
Each SM contains several essential components:
- **CUDA Cores (NVIDIA)/Stream Processors (AMD):** These are the basic
arithmetic logic units (ALUs) that perform integer and floating-point
calculations. A single SM can contain dozens or hundreds of these cores.
- **Tensor Cores (NVIDIA)/Matrix Cores (AMD):** Specialized units optimized for
matrix multiplication and convolution operations.
- **Special Function Units (SFUs):** Handle complex mathematical operations like
trigonometry, square roots, and exponential functions.
- **[Register](register.mdx) Files:** Ultra-fast storage that holds intermediate
results and thread-specific data. Modern SMs can have hundreds of kilobytes of
register space shared among active [threads](thread.mdx).
- **Shared Memory/L1 Cache:** A programmable, low-latency memory space that
enables data sharing between threads. This memory is typically configurable
between shared memory and L1 cache functions.
- **Load/Store Units:** Manage data movement between different memory spaces,
handling memory access requests from threads.
---
## Thread block
In GPU programming, a thread block (also known as *workgroup* on AMD GPUs) is a
subset of threads within a [grid](grid.mdx), which is the top-level
organizational structure of the [threads](thread.mdx) executing a
[kernel](kernel.mdx) function. As the primary building block for workload
distribution, thread blocks serve multiple crucial purposes:
- First, they break down the overall workload — managed by the grid — of a
kernel function into smaller, more manageable portions that can be processed
independently. This division allows for better resource utilization and
scheduling flexibility across multiple [streaming
multiprocessors](streaming-multiprocessor.mdx) (SMs) in the GPU.
- Second, thread blocks provide a scope for threads to collaborate through
shared memory and synchronization primitives, enabling efficient parallel
algorithms and data sharing patterns.
- Finally, thread blocks help with scalability by allowing the same program to
run efficiently across different GPU architectures, as the hardware can
automatically distribute blocks based on available resources.
The programmer specifies the number of thread blocks in a grid and how they are
arranged across one, two, or three dimensions. Each block within the grid is
assigned a unique [block index](block-index.mdx) that determines its position
within the grid. Similarly, the programmer also specifies the number of threads
per thread block and how they are arranged across one, two, or three dimensions.
Each thread within a block is assigned a unique [thread index](thread-index.mdx)
that determines its position within the block.
The GPU assigns each thread block within the grid to a streaming multiprocessor
(SM) for execution. The SM groups the threads within a block into fixed-size
subsets called [warps](warp.mdx), consisting of either 32 or 64 threads each
depending on the particular GPU architecture. The SM's warp scheduler manages
the execution of warps on the SM's cores.
Threads within a block can share data through [shared memory](memory.mdx)
and synchronize using built-in mechanisms, but they cannot directly communicate
with threads in other blocks.
---
## Thread index
In GPU programming, a thread index uniquely identifies the position of a
[thread](thread.mdx) within a particular [thread block](thread-block.mdx)
executing a [kernel](kernel.mdx) function on the GPU. A thread block is a subset
of threads in a [grid](grid.mdx), which is the top-level organizational
structure of the threads executing a kernel function. Each block within the grid
is also assigned a unique block index, which identifies the block's position
within the grid. The combination of block index and thread index uniquely
identifies the thread's overall position within the grid, and is used to
determine which part of the problem each thread should work on.
Because a programmer can arrange threads within a thread block across one, two,
or three dimensions, a thread index is a 3-element vector of x, y, and z
coordinates. For 2-dimensional arrangements, the z coordinate of all thread
indices is 0, and for 1-dimensional arrangements, both the y and z coordinates
of all thread indices are 0.
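The mapping from a 3-D block index and thread index to a unique flat position can be sketched as follows. This is one common enumeration (with x varying fastest, as in CUDA-style layouts); the dimension values are illustrative, not from any real kernel.

```python
# Sketch: mapping a 3-D block index and thread index to a flat global
# position, assuming an x-fastest layout. Dimensions are illustrative.

def linear_thread_position(block_idx, thread_idx, grid_dim, block_dim):
    """Return a unique flat index for a thread across the whole grid."""
    bx, by, bz = block_idx
    tx, ty, tz = thread_idx
    gx, gy, _ = grid_dim      # blocks per grid in each dimension
    dx, dy, dz = block_dim    # threads per block in each dimension

    block_linear = bx + by * gx + bz * gx * gy
    thread_linear = tx + ty * dx + tz * dx * dy
    return block_linear * (dx * dy * dz) + thread_linear

# A 1-dimensional arrangement: y and z coordinates are 0, as described above.
print(linear_thread_position((2, 0, 0), (5, 0, 0),
                             grid_dim=(4, 1, 1), block_dim=(32, 1, 1)))  # 69
```

In the 1-D case this reduces to the familiar `block_index * threads_per_block + thread_index`, which is how a kernel typically decides which element of the problem each thread works on.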
---
## Thread
In GPU programming, a thread (also known as a *work unit* on AMD GPUs) is the
smallest unit of execution within a [kernel](kernel.mdx) function. Threads are
grouped into [thread blocks](thread-block.mdx) (or *workgroups* on AMD GPUs),
which are further organized into a [grid](grid.mdx).
The programmer specifies the number of thread blocks in a grid and how they are
arranged across one, two, or three dimensions. Each block within the grid is
assigned a unique [block index](block-index.mdx) that determines its position
within the grid. Similarly, the programmer also specifies the number of threads
per thread block and how they are arranged across one, two, or three dimensions.
Each thread within a block is assigned a unique [thread index](thread-index.mdx)
that determines its position within the block.
The GPU assigns each thread block within the grid to a [streaming
multiprocessor](streaming-multiprocessor.mdx) (SM) for execution. The SM groups
the threads within a block into fixed-size subsets called [warps](warp.mdx),
consisting of either 32 or 64 threads each depending on the particular GPU
architecture. The SM's warp scheduler manages the execution of warps on the SM's
cores.
The SM allocates a set of [registers](register.mdx) for each thread to store
and process values private to that thread. The registers are associated with
that thread throughout its lifetime, even if the thread is not currently
executing on the SM's cores (for example, if it is blocked waiting for data from
memory). Each thread also has access to [local memory](memory.mdx) to store
statically allocated arrays, spilled registers, and other elements of the
thread's call stack.
Threads within a block can share data through shared memory and synchronize
using built-in mechanisms, but they cannot directly communicate with threads in
other blocks.
---
## Warp
In GPU programming, a warp (also known as a *wavefront* on AMD GPUs) is a subset
of [threads](thread.mdx) from a [thread block](thread-block.mdx) that execute
together in lockstep. When a GPU assigns a thread block to execute on a
[streaming multiprocessor](streaming-multiprocessor.mdx) (SM), the SM divides
the thread block into warps of 32 or 64 threads, with the exact size depending
on the GPU architecture.
If a thread block contains a number of threads not evenly divisible by the warp
size, the SM creates a partially filled final warp that still consumes the full
warp's resources. For example, if a thread block has 100 threads and the warp
size is 32, the SM creates:
- 3 full warps of 32 threads each (96 threads total)
- 1 partial warp with only 4 active threads but still occupying a full warp's
worth of resources (32 thread slots)
The SM effectively disables the unused thread slots in partial warps, but these
slots still consume hardware resources. For this reason, developers generally
should make thread block sizes a multiple of the warp size to optimize resource
usage.
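The warp arithmetic above is just ceiling division; a minimal sketch:

```python
import math

def warp_count(threads_per_block, warp_size=32):
    """Number of warps the SM creates for one thread block (ceiling division)."""
    return math.ceil(threads_per_block / warp_size)

# The example above: 100 threads with a warp size of 32.
full, remainder = divmod(100, 32)
print(warp_count(100))  # 4 warps total
print(full, remainder)  # 3 full warps; 4 active threads in the partial warp
```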
Each thread in a warp executes the same instruction at the same time on
different data, following the single instruction, multiple threads (SIMT)
execution model. If threads within a warp take different execution paths (called
*warp divergence*), the warp serially executes each branch path taken, disabling
threads that are not on that path. This means that optimal performance is
achieved when all threads in a warp follow the same execution path.
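The cost of divergence can be illustrated with a toy model: each distinct branch outcome within a warp adds one serialized pass, with off-path threads masked out. This is a simplification of real SIMT hardware, which also handles reconvergence, but it captures why uniform branches are fastest.

```python
# Toy model of SIMT branch serialization: each distinct branch path taken
# by threads in a warp costs one serialized execution pass.

def divergence_passes(branch_taken_per_thread):
    """Each distinct branch outcome in the warp adds one serialized pass."""
    return len(set(branch_taken_per_thread))

# All 32 threads agree: one pass, full throughput.
uniform = [True] * 32
# Threads split on (thread_index % 2): two serialized passes.
divergent = [i % 2 == 0 for i in range(32)]

print(divergence_passes(uniform))    # 1
print(divergence_passes(divergent))  # 2
```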
An SM can actively manage multiple warps from different thread blocks
simultaneously, helping keep execution units busy. For example, the warp
scheduler can quickly switch to another ready warp if the current warp's threads
must wait for memory access.
Warps deliver several key performance advantages:
- The hardware needs to manage only warps instead of individual threads,
reducing scheduling overhead
- Threads in a warp can access contiguous memory locations efficiently through
memory coalescing
- The hardware automatically synchronizes threads within a warp, eliminating the
need for explicit synchronization
- The warp scheduler can hide memory latency by switching between warps,
maximizing compute resource utilization
---
## Glossary
import MDXListing from '@site/src/components/Listing/MDXListing';
Explanations for some terms and concepts you'll encounter in the Modular docs.
## GPU terms
export const gpuTerms = [
'gpu/*.mdx'
]
## AI terms
export const aiTerms = [
'ai/*.mdx'
]
---
## Modular Documentation
import Homepage, { GetStartedButton } from "@site/src/components/Homepage";
import CodeNote from "@site/src/components/Homepage/CodeNote";
import { ArrowTransfer } from "@site/src/shared/Svgs/ArrowTransfer";
import { ArrowCloud } from "@site/src/shared/Svgs/ArrowCloud";
import { DesktopCode } from "@site/src/shared/Svgs/DesktopCode";
import { AIChip } from "@site/src/shared/Svgs/AIChip";
import { RecipesIcon } from "@site/src/shared/Svgs/RecipesIcon";
import { OpenBook } from "@site/src/shared/Svgs/OpenBook";
import { PuzzleIcon } from "@site/src/shared/Svgs/PuzzleIcon";
## Modular Documentation
The Modular Platform accelerates AI inference and abstracts hardware
complexity. Using our Docker container, you can deploy a GenAI model from
Hugging Face with an OpenAI-compatible endpoint on a wide range of hardware.
And if you need to customize the model or tune a GPU kernel, Modular
provides a depth of model extensibility and GPU programmability that you
won't find anywhere else.
```python title="python"
from openai import OpenAI
client = OpenAI(base_url="http://0.0.0.0:8000/v1", api_key="EMPTY")
completion = client.chat.completions.create(
model="google/gemma-3-27b-it",
messages=[
{"role": "user", "content": "Who won the world series in 2020?"}
],
)
print(completion.choices[0].message.content)
```
export const sectionCards = [
{
title: "Serving",
description:
"Modular’s serving library is compatible with OpenAI APIs, so you can own your endpoint with minimal client-side code changes.",
to: "/max/container",
icon: ,
},
{
title: "Deploying",
description:
      "Deploy your GenAI models to the cloud and scale your deployments across heterogeneous GPU clusters.",
to: "/mammoth/",
icon: ,
},
{
title: "Developing",
description:
"The Modular platform provides full extensibility, so you can write custom ops, hardware-agnostic GPU kernels, and more.",
to: "/max/develop/",
icon: ,
},
{
title: "Programming with Mojo🔥",
description:
"Mojo is a Python-style programming language that allows you to write code for both CPUs and GPUs. ",
to: "/mojo/manual/",
icon: ,
},
];
export const learningToolCards = [
{
title: "Agentic Cookbook",
description:
"Turn-key applications that use GenAI models with the Modular platform.",
href: "https://modul.ar/cookbook",
icon: ,
},
{
title: "GPU Puzzles",
description: "A hands-on guide to mastering GPU programming with Mojo.",
href: "https://builds.modular.com/puzzles",
icon: ,
},
{
title: "Build an LLM with MAX",
description:
"Learn to build an LLM from scratch with MAX.",
href: "https://llm.modular.com/",
icon: ,
},
];
---
## Disaggregated inference
import ContactSection from '@site/src/components/ContactSection';
Disaggregated inference is a serving architecture pattern designed for large
language models (LLMs), particularly decoder-only transformer models like those
in the LLaMA or GPT model families. In decoder-only transformers, the process
of generating model output is divided into two distinct phases: prefill and
decode.
With disaggregated inference, these phases are executed on different hardware
resources. You might see this technique referred to by several names, including
disaggregated inference, disaggregated prefill, or disaggregated serving. All
of these describe the same core idea: separating the model's inference phases
and providing each phase with dedicated resources optimized to improve
performance and scalability.
:::note
Mammoth is the technology behind advanced features like disaggregated inference,
routing, and scaling in Modular's Dedicated Endpoint and Enterprise
[editions](https://www.modular.com/pricing).
[Get in touch](https://www.modular.com/request-demo) to learn how Mammoth
enables more efficient, large-scale model serving.
:::
## When to use disaggregated inference
Disaggregated inference is particularly valuable if your priority is minimizing
latency. Since the prefill stage is compute-intensive and the decode stage is
more memory-bound, isolating the two stages and allocating them to different
GPUs or GPU nodes reduces resource contention and helps achieve both faster
time-to-first-token and smoother token streaming.
Because disaggregated inference gives you separate levers to manage the prefill
and decode phases independently, it is especially effective for improving tail
latency, such as P95: the time within which 95% of requests complete. By
optimizing tail latency, you reduce delays for the slowest requests
and can improve overall responsiveness.
Disaggregation itself doesn't directly increase throughput, but it enables more
granular parallelism strategies and resource allocation, which can increase
processing capacity. This flexibility allows you to optimize each phase
appropriately and scale prefill and decode nodes independently as needed,
improving GPU utilization and overall efficiency without over-provisioning
capacity just to handle peak workloads.
Additionally, disaggregated inference offers flexibility in heterogeneous or
resource-constrained environments. You can match each phase with hardware that
suits its specific demands.
## How disaggregated inference works
LLM inference involves two distinct phases known as prefill and decode, each
with unique performance characteristics that affect how systems should allocate
and optimize resources.
A simplified illustration of the separate prefill and
decode nodes used in a disaggregated inference serving architecture.
Prefill, also known as context encoding, is the initial phase where the model
processes the entire input prompt. During this phase, the model performs a full
forward pass to initialize its key-value (KV) cache and predict the first
output token. This cache stores the intermediate attention states necessary for
generating subsequent tokens. The prefill phase is compute-intensive,
especially in the case of long user prompts, as it involves large-scale matrix
operations that demand high floating-point throughput. The metric associated
with this phase is often referred to as Time-to-First-Token (TTFT), indicating
the duration from receiving the input prompt to producing the first output
token.
Following prefill, the model enters the decode phase, or token generation. In
this phase, the model generates output tokens one at a time, using the KV cache
initialized during prefill. By leveraging this cache, the model can quickly
access previously computed information without reprocessing the full input each
time. As a result, the decoding phase is less compute-intensive per token but
becomes memory-bound, relying heavily on efficient access to the cached data.
The key performance metric here is Inter-Token Latency (ITL), which measures
the time taken to generate each subsequent token after the first.
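The two metrics can be measured from any token stream by timestamping each token as it arrives. The sketch below uses a simulated stream in place of a real model endpoint; the sleep durations are illustrative only.

```python
import time

def measure_streaming_latency(token_stream):
    """Return (TTFT, list of inter-token latencies) for a token iterator."""
    start = time.monotonic()
    ttft = None
    itls = []
    prev = None
    for _ in token_stream:
        now = time.monotonic()
        if ttft is None:
            ttft = now - start       # Time-to-First-Token: dominated by prefill
        else:
            itls.append(now - prev)  # Inter-Token Latency: dominated by decode
        prev = now
    return ttft, itls

# Simulated stream standing in for a real model endpoint.
def fake_stream():
    time.sleep(0.05)      # prefill: compute-bound forward pass
    yield "first"
    for _ in range(3):
        time.sleep(0.01)  # decode: memory-bound, one token at a time
        yield "next"

ttft, itls = measure_streaming_latency(fake_stream())
print(ttft >= 0.05, len(itls) == 3)  # True True
```

Disaggregation gives you separate hardware levers for each metric: prefill capacity drives TTFT, and decode capacity drives ITL.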
Disaggregated inference involves separating these two phases onto different
GPUs or GPU nodes. By doing so, each phase can be optimized independently.
Prefill workloads can be routed to hardware with high compute throughput to
handle the intensive matrix operations required to process long input prompts.
Meanwhile, decode workloads can be assigned to hardware with fast memory
access, which is better suited to the sequential, cache-dependent nature of
token generation. This separation reduces contention between compute-bound and
memory-bound tasks, improves resource utilization, and allows for more scalable
and predictable inference performance.
## Become a design partner
If you're exploring disaggregated inference for your deployments, start by
analyzing your workload to spot any imbalances between prompt processing and
token generation. Check whether your GPUs are underutilized during either
phase. If you're encountering these challenges, feel free to reach out to talk
to an AI expert.
---
## Scale your GenAI deployments
import ContactSection from '@site/src/components/ContactSection';
The [Modular Platform](/max/intro) provides dedicated endpoints and
enterprise-grade scaling for inference workloads. This scaling logic is powered
by Mammoth, a Kubernetes-native distributed AI serving tool that makes it
easier to run and manage LLMs at scale using MAX as a backend for optimal model
performance. It's designed to maximize hardware efficiency with minimal
configuration, even when running multiple models across thousands of nodes.
Figure 1. A simplified diagram of how the Modular Platform
scales your GenAI deployment.
The Mammoth control plane automatically selects the best available hardware to
meet performance targets when deploying a model and supports both manual and
automatic scaling. Mammoth's built-in orchestrator intelligently routes traffic,
taking into account hardware load, GPU memory, and caching states. You can
deploy and serve multiple models simultaneously across different hardware types
or versions without complex setup or duplication of infrastructure.
:::note
Mammoth powers advanced routing and scaling capabilities behind the scenes for
Modular's Dedicated Endpoint and Enterprise [editions](https://www.modular.com/pricing).
[Get in touch](https://www.modular.com/request-demo) to learn more about
how Mammoth can support your workloads at scale.
:::
## Access to Mammoth
If you need to serve one or more LLMs at scale with high performance and
minimal operational overhead, you can do so with Modular's Dedicated Endpoint
or Enterprise [editions](https://www.modular.com/pricing), which use Mammoth to
power routing and scaling capabilities.
Mammoth makes a difference when:
- You're running inference across heterogeneous GPU clusters (NVIDIA and AMD)
and need optimized, vendor-agnostic orchestration.
- You want a self-hosted, low-configuration deployment experience that works
out of the box, regardless of hardware or cloud provider.
- You need to dynamically scale workloads based on traffic and resource
availability, with fine-grained control over model placement and scheduling.
- You're managing fleets of models and want a unified serving layer without
duplicating infrastructure.
- You're working in a Kubernetes environment and want native integration that's
easy to operate and extend.
- You want to optimize total cost of ownership with cluster-level efficiency
features like disaggregated inference and KV cache-aware routing.
Additionally, because Mammoth is built on the MAX framework, you can use its
APIs and tools to customize and optimize every layer of the stack, from
high-level orchestration down to GPU kernels written in Mojo.
## How Mammoth works
Mammoth consists of a lightweight control plane, an intelligent
[orchestrator](/mammoth/orchestrator), and advanced optimizations such as
[disaggregated inference](/mammoth/disaggregated-inference), all working
together to efficiently deploy and run models across diverse hardware
environments.
Figure 2. An overview of the Mammoth components, including the control plane,
orchestrator, and disaggregated inference on separate prefill and decode nodes.
At the heart of Mammoth is its control plane, which takes care of setting up,
running, and scaling models automatically. Just provide the model ID (such as
`modularai/Llama-3.1-8B-Instruct`) or a path to the model on an external
storage provider like S3, and the control plane handles the rest.
You can interact with the control plane for:
- Model deployment: Launch models with a single command.
- Model management: Modify or delete deployed models.
- Multi-model orchestration: Run multiple models efficiently across shared
infrastructure.
- Scaling: Adjust replicas manually or let Mammoth autoscale intelligently.
- Resource allocation: Automatically allocate GPU resources to model deployments.
The Mammoth control plane extends the Kubernetes API with
[custom resource](https://kubernetes.io/docs/concepts/extend-kubernetes/api-extension/custom-resources/)
definitions (CRDs) and controls those resources with an
[operator](https://kubernetes.io/docs/concepts/extend-kubernetes/operator/).
When you create, update, or delete a resource, the control plane provisions
infrastructure, deploys or reconfigures models, and cleans up resources as
needed.
### Deploy models
With Mammoth running behind the scenes, deploying models in Modular's Dedicated
Endpoint and Enterprise editions is designed to be simple. You choose the model
you want to serve and define your resource requirements, and Mammoth's control
plane takes care of the rest. It automatically discovers available NVIDIA or
AMD GPUs, schedules the workload across the cluster, and scales as needed.
Whether you're serving a single large model or multiple models at once, Mammoth
handles orchestration and optimization so you can focus on your application
rather than infrastructure.
### Scale deployments
The control plane adjusts the deployment to the desired number of replicas and
allocates resources accordingly. For production use, intelligent autoscaling is
built in and configurable.
### Allocate resources
You can fine-tune resource allocation for each deployment. For example, with
[disaggregated inference](/mammoth/disaggregated-inference), you can assign
separate GPU resources to nodes that handle prefill and decode stages
independently.
## Become a design partner
Mammoth is currently only available through Modular's early access program where
we're actively partnering with select organizations as design partners. Design
partners get early access to new features and share feedback to help shape the
future of Mammoth.
Talk to an AI expert to learn more about how Mammoth can support your use case
and help you scale with confidence.
---
## Routing and orchestration
import ContactSection from '@site/src/components/ContactSection';
The orchestrator is responsible for distributing incoming inference
requests to the appropriate worker node in a cluster. This orchestration layer
plays a critical role in performance, load balancing, memory optimization, and
user experience.
Rather than simply forwarding requests to the next available worker, the
orchestrator uses configurable routing strategies to intelligently direct
traffic. Each routing strategy has trade-offs, and the ideal strategy depends
on the characteristics of your workload.
## How the orchestrator works
The orchestrator routes inference requests across distributed workers.
The orchestrator receives a prompt from an HTTP server, then analyzes the
request to extract information relevant to the specific routing strategy. The
orchestrator then selects a worker based on the specified routing algorithm and
current cluster state, proxies the request to the relevant worker, and streams
the response back to the user.
:::note
Orchestration with Mammoth is still in preview and some aspects may change as
we refine the implementation. Expect ongoing improvements and potential
adjustments based on feedback and performance optimizations.
:::
An overview of steps taken by the Mammoth orchestrator.
## Routing options
You can configure the routing strategy based on your deployment goals. For
stateless requests and broad load balancing, the round robin or least request
routing options work well. If you're optimizing for cache reuse or continuity in
conversation, prefix-aware, sticky sessions, or KV cache-aware routing may
offer significant performance gains. We also provide a random routing algorithm
for benchmarking or experimental purposes.
| Name | Strategy | Use case |
|-------------------|--------------------------------------------------------------------------|----------------------------------------------------------------|
| KV cache-aware | Routes based on shared tokens or document chunks in the KV cache | Repeated prompts in chatbots, agents, or RAG-style workflows |
| Least request | Sends requests to the worker with the fewest active requests | Mixed workloads with variable size or latency requirements |
| Prefix-aware | Uses consistent hashing on prompt prefixes to group similar requests | Prompts with shared templates or recurring task descriptions |
| Random | Selects a backend worker at random | Benchmarking and exposing latency variability |
| Round robin | Distributes requests evenly across all workers in sequential order | Stateless, uniform tasks without caching needs |
| Sticky session | Routes requests with the same session ID to the same worker | Session-based chat or apps needing memory and continuity |
### KV cache-aware
KV cache-aware routing manages requests based on the contents of the KV cache
on each worker. You might use KV cache-aware routing if you're running a
retrieval-augmented generation (RAG) system where most queries share common
document chunks or similar inputs, but not identical prefixes. KV cache-aware
routing is especially useful in the following scenarios:
- For high-throughput workloads with many repeating or similar tokens across
queries.
- When you want to minimize redundant computation across diverse,
overlapping queries.
### Least request
Least request routing sends new inference requests to the worker currently
handling the fewest active requests. This helps balance load dynamically and
reduces the chance of overloading any single worker. You might use least
request routing when serving a mix of both small and large generation tasks in
order to avoid piling multiple large requests on the same node. Least request
routing is especially useful in the following situations:
- When some workers receive heavier workloads or respond slower.
- For variable-length or unpredictable inference tasks.
- When you're optimizing for low tail latency.
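The core of least request routing is a minimum over in-flight request counts. A minimal sketch, with illustrative worker names and counts (not part of the Mammoth API):

```python
# Sketch of least-request selection: pick the worker with the fewest
# in-flight requests. Worker names and counts are illustrative.

def pick_least_loaded(active_requests):
    """Choose the worker with the fewest in-flight requests."""
    return min(active_requests, key=active_requests.get)

workers = {"worker-a": 7, "worker-b": 2, "worker-c": 5}
print(pick_least_loaded(workers))  # worker-b
```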
### Prefix-aware
Prefix-aware routing, also known as consistent hashing, examines the prompt
prefix in an incoming request and routes it to the worker handling requests with
the same prefix. For example, if a support chatbot frequently receives the
prefix `{"role": "system", "content": "You are a helpful assistant."}` followed
by user-specific questions, prefix-aware routing keeps that common prefix cached
on a single node. When a worker becomes saturated with requests for a popular
prefix, the orchestrator automatically distributes the load by spilling over to
additional workers, maintaining partial cache locality while balancing traffic.
Prefix-aware request routing is especially useful in the following situations:
- When many users send queries that start with the same instructions or
template.
- If users frequently issue similar or identical prompts, like a
recurring task description or persona.
- In multi-turn conversations where session stickiness isn't enabled.
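The prefix-grouping idea can be sketched with a simple hash of the prompt's leading characters. This is a simplified stand-in for true consistent hashing (which also minimizes remapping when workers join or leave); the prefix length and worker pool are illustrative, not Mammoth defaults.

```python
import hashlib

# Sketch of prefix-aware routing: hash the first part of the prompt so
# requests sharing a prefix land on the same worker. Prefix length and
# worker names are illustrative assumptions.

def route_by_prefix(prompt, workers, prefix_len=24):
    digest = hashlib.sha256(prompt[:prefix_len].encode()).hexdigest()
    return workers[int(digest, 16) % len(workers)]

workers = ["worker-a", "worker-b", "worker-c"]
system = "You are a helpful assistant. "

# Two requests sharing the system-prompt prefix route to the same worker,
# so the cached prefix stays hot on that node.
w1 = route_by_prefix(system + "How do I reset my password?", workers)
w2 = route_by_prefix(system + "Where is my invoice?", workers)
print(w1 == w2)  # True
```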
### Random
Random routing selects a backend worker at random from the pool of available
endpoints for each incoming request. Random routing is useful when you want to
eliminate routing bias and observe average worker performance under distributed
load. It can help identify variability in latency or behavior across nodes
without favoring specific ones. Random routing is especially useful for
benchmarking use cases.
### Round robin
The round robin routing algorithm distributes incoming requests
evenly across all available workers in sequential order. Once the orchestrator
reaches the last worker in the list, it cycles back to the first. You might use
round robin routing if you're running a batch of isolated tasks that don't
require any request context or caching.
Round robin routing is especially useful in the following situations:
- For stateless or homogeneous workloads where each request is independent.
- For testing environments or basic load distribution.
### Sticky session
Sticky session routing sends a user's requests to the same worker node for the
duration of their session. A session is identified by checking for a session ID
value in the request HTTP header. If this header is not present, the
orchestrator falls back to round robin routing.
You might use sticky session routing for a chatbot with user interaction, where
keeping their requests on the same worker node avoids reloading context
repeatedly. Sticky session routing is especially useful in the following
situations:
- When in-flight session state (for example, conversational memory) is
maintained on the server.
- For chatbots or streaming applications where continuity is important.
- When memory locality is key to performance.
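The sticky-session behavior described above, including the round-robin fallback for requests without a session ID, can be sketched as follows. The header name and worker names are illustrative assumptions, not the actual Mammoth configuration.

```python
import itertools

# Sketch of sticky-session routing with a round-robin fallback when no
# session ID header is present. Header name and workers are illustrative.

workers = ["worker-a", "worker-b", "worker-c"]
_round_robin = itertools.cycle(workers)
_sessions = {}

def route(headers):
    session_id = headers.get("x-session-id")
    if session_id is None:
        return next(_round_robin)      # fallback: plain round robin
    if session_id not in _sessions:
        _sessions[session_id] = next(_round_robin)
    return _sessions[session_id]       # sticky: same worker every time

a = route({"x-session-id": "alice"})
b = route({"x-session-id": "alice"})
print(a == b)  # True: the session stays on one worker
```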
## Become a design partner
To get the most out of prefix-aware routing and other advanced strategies, you
can explore [prefix caching](/max/serve/prefix-caching) and other serving layer
optimizations in MAX.
To get started with Mammoth's cluster-based deployments and optimize request
routing for your specific use case, you can reach out to us and talk to an AI
expert.
---
## Common
```c
#include "max/c/common.h"
```
**Functions:**
### `M_version()`
> const char \*M\_version()
Gets the MAX Engine version.
* **Returns:**
A string containing the semantic version of the MAX Engine.
### `M_newStatus()`
> [M\_Status](types.md#_CPPv48M_Status) \*M\_newStatus()
Creates a new status object.
This is required as an argument for several functions, such as [`M_newRuntimeContext()`](context.md#context_8h_1a46a6c670f73e1ce560f3c2cc1de93175) and [`M_compileModel()`](model.md#model_8h_1a88afca26a64b945885e1e1a0d09b5750). They will update the status object and you can check for errors with [`M_isError()`](#common_8h_1adb7a61f1c8f9c5e7964e8788cd437468) and get the status message with [`M_getError()`](#common_8h_1aa294beac43a0884cef8386e69a6bfc1b). For example:
```c
M_Status *status = M_newStatus();
M_RuntimeConfig *runtimeConfig = M_newRuntimeConfig();
M_RuntimeContext *context = M_newRuntimeContext(runtimeConfig, status);
if (M_isError(status)) {
logError(M_getError(status));
return EXIT_FAILURE;
}
```
* **Returns:**
A pointer to the new status object. You are responsible for the memory associated with the pointer returned. You can deallocate the memory by calling [`M_freeStatus()`](#common_8h_1ab5067fd51a5696b3679f7f629d3329c4).
### `M_getError()`
> const char \*M\_getError(const [M\_Status](types.md#_CPPv48M_Status) \*status)
Gets an error message from the `M_Status` parameter.
You should call this only if [`M_isError()`](#common_8h_1adb7a61f1c8f9c5e7964e8788cd437468) is true.
* **Parameters:**
status – The status object for reporting errors and other messages.
* **Returns:**
A pointer to a null-terminated string containing the error message.
### `M_isError()`
> int M\_isError(const [M\_Status](types.md#_CPPv48M_Status) \*status)
Checks if status holds an error value.
* **Parameters:**
status – The status object for reporting errors and other messages.
* **Returns:**
`0` if there is no error, `1` otherwise.
### `M_freeStatus()`
> void M\_freeStatus([M\_Status](types.md#_CPPv48M_Status) \*status)
Deallocates the memory for the status object. No-op if `status` is `NULL`.
* **Parameters:**
status – The status object for reporting errors and other messages.
---
## Context
```c
#include "max/c/context.h"
```
**Functions:**
### `M_newRuntimeConfig()`
> [M\_RuntimeConfig](types.md#_CPPv415M_RuntimeConfig) \*M\_newRuntimeConfig()
Creates a new runtime config.
This configures runtime details such as the number of threads and log level.
By default, the config object’s number of threads will be set to `0`, which is internally used to refer to the number of physical processors in the first socket in the system. You can change this with `M_setNumThreads()`.
You need this as an argument for [`M_newRuntimeContext()`](#context_8h_1a46a6c670f73e1ce560f3c2cc1de93175).
* **Returns:**
A pointer to the new runtime config. You are responsible for the memory associated with the pointer returned. You can deallocate the memory by calling [`M_freeRuntimeConfig()`](#context_8h_1a47f7e22f7f71da9ab5fb3a1886911610).
### `M_freeRuntimeConfig()`
> void M\_freeRuntimeConfig([M\_RuntimeConfig](types.md#_CPPv415M_RuntimeConfig) \*config)
Deallocates the memory for a runtime config. No-op if `config` is `NULL`.
* **Parameters:**
config – The runtime config.
### `M_runtimeConfigAddDevice()`
> void M\_runtimeConfigAddDevice([M\_RuntimeConfig](types.md#_CPPv415M_RuntimeConfig) \*config, [M\_Device](types.md#_CPPv48M_Device) \*device)
Adds a device to be accessible from the runtime.
* **Parameters:**
* config – The runtime config.
* device – The device to add to the runtime config.
### `M_newRuntimeContext()`
> [M\_RuntimeContext](types.md#_CPPv416M_RuntimeContext) \*M\_newRuntimeContext(const [M\_RuntimeConfig](types.md#_CPPv415M_RuntimeConfig) \*config, [M\_Status](types.md#_CPPv48M_Status) \*status)
Creates a runtime context.
The context is an application-level object that sets up various resources such as a thread pool and allocators during inference. You need this before you can call [`M_compileModel()`](model.md#model_8h_1a88afca26a64b945885e1e1a0d09b5750).
It’s expected that there’s only one runtime context active in an inference session at a time. We recommend you create one context and use it throughout your application.
For example:
```c
M_Status *status = M_newStatus();
M_RuntimeConfig *runtimeConfig = M_newRuntimeConfig();
M_RuntimeContext *context = M_newRuntimeContext(runtimeConfig, status);
if (M_isError(status)) {
logError(M_getError(status));
return EXIT_FAILURE;
}
```
* **Parameters:**
* config – The runtime config, from [`M_newRuntimeConfig()`](#context_8h_1a963f1d4eefd812ba8691acf516007cfc).
* status – The status object for reporting errors. It is filled with an error message if construction of the runtime context fails.
* **Returns:**
A pointer to the runtime context object. On success, this is a valid pointer. On failure, this is a `NULL` pointer with an error message in the status. You are responsible for the memory associated with the pointer returned. You can deallocate the memory by calling [`M_freeRuntimeContext()`](#context_8h_1a2434a11d8d65890c66f6b5516243a730).
### `M_freeRuntimeContext()`
> void M\_freeRuntimeContext([M\_RuntimeContext](types.md#_CPPv416M_RuntimeContext) \*context)
Deallocates the memory for a runtime context. No-op if `context` is `NULL`.
* **Parameters:**
context – The runtime context.
### `M_setDebugPrintOptions()`
> void M\_setDebugPrintOptions([M\_RuntimeContext](types.md#_CPPv416M_RuntimeContext) \*context, [M\_ResultOutputStyle](types.md#_CPPv419M_ResultOutputStyle) style, unsigned int precision, const char \*directory)
Sets the options for debug printing of tensors when executing a model.
* **Parameters:**
* context – The runtime context.
* style – The way the data will be printed.
* precision – The floating point print out precision.
* directory – The directory to store binary output.
### `M_setMojoDefineBool()`
> void M\_setMojoDefineBool([M\_RuntimeContext](types.md#_CPPv416M_RuntimeContext) \*context, const char \*key, bool value)
Sets a Mojo compile-time define with a boolean value.
* **Parameters:**
* context – The runtime context.
* key – The name of the define.
* value – The boolean to set the define to.
### `M_setMojoDefineInt()`
> void M\_setMojoDefineInt([M\_RuntimeContext](types.md#_CPPv416M_RuntimeContext) \*context, const char \*key, int value)
Sets a Mojo compile-time define with an integer value.
* **Parameters:**
* context – The runtime context.
* key – The name of the define.
* value – The integer to set the define to.
### `M_setMojoDefineString()`
> void M\_setMojoDefineString([M\_RuntimeContext](types.md#_CPPv416M_RuntimeContext) \*context, const char \*key, const char \*value)
Sets a Mojo compile-time define with a string value.
* **Parameters:**
* context – The runtime context.
* key – The name of the define.
* value – The string to set the define to.
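As a sketch, the three define setters above might be used together on a runtime context like this. The define keys (`"use_fast_path"`, `"num_warmups"`, `"log_level"`) are hypothetical names for illustration only, not documented keys:

```c
#include <stdbool.h>
#include "max/c/context.h"

// Sketch only: the key names below are hypothetical.
void configureDefines(M_RuntimeContext *context) {
  M_setMojoDefineBool(context, "use_fast_path", true);
  M_setMojoDefineInt(context, "num_warmups", 3);
  M_setMojoDefineString(context, "log_level", "debug");
}
```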
---
## C API
You can use the following C APIs to run inference with MAX Engine.
## API headers
Each of the following pages represents one of the C API header files:
* [Common](common.md)
* [`M_version()`](common.md#_CPPv49M_versionv)
* [`M_newStatus()`](common.md#_CPPv411M_newStatusv)
* [`M_getError()`](common.md#_CPPv410M_getErrorPK8M_Status)
* [`M_isError()`](common.md#_CPPv49M_isErrorPK8M_Status)
* [`M_freeStatus()`](common.md#_CPPv412M_freeStatusP8M_Status)
* [Context](context.md)
* [`M_newRuntimeConfig()`](context.md#_CPPv418M_newRuntimeConfigv)
* [`M_freeRuntimeConfig()`](context.md#_CPPv419M_freeRuntimeConfigP15M_RuntimeConfig)
* [`M_runtimeConfigAddDevice()`](context.md#_CPPv424M_runtimeConfigAddDeviceP15M_RuntimeConfigP8M_Device)
* [`M_newRuntimeContext()`](context.md#_CPPv419M_newRuntimeContextPK15M_RuntimeConfigP8M_Status)
* [`M_freeRuntimeContext()`](context.md#_CPPv420M_freeRuntimeContextP16M_RuntimeContext)
* [`M_setDebugPrintOptions()`](context.md#_CPPv422M_setDebugPrintOptionsP16M_RuntimeContext19M_ResultOutputStylejPKc)
* [`M_setMojoDefineBool()`](context.md#_CPPv419M_setMojoDefineBoolP16M_RuntimeContextPKcb)
* [`M_setMojoDefineInt()`](context.md#_CPPv418M_setMojoDefineIntP16M_RuntimeContextPKci)
* [`M_setMojoDefineString()`](context.md#_CPPv421M_setMojoDefineStringP16M_RuntimeContextPKcPKc)
* [Model](model.md)
* [`M_newCompileConfig()`](model.md#_CPPv418M_newCompileConfigv)
* [`M_setModelPath()`](model.md#_CPPv414M_setModelPathP15M_CompileConfigPKc)
* [`M_compileModel()`](model.md#_CPPv414M_compileModelPK16M_RuntimeContextPP15M_CompileConfigP8M_Status)
* [`M_waitForCompilation()`](model.md#_CPPv420M_waitForCompilationP20M_AsyncCompiledModelP8M_Status)
* [`M_compileModelSync()`](model.md#_CPPv418M_compileModelSyncPK16M_RuntimeContextPP15M_CompileConfigP8M_Status)
* [`M_initModel()`](model.md#_CPPv411M_initModelPK16M_RuntimeContextPK20M_AsyncCompiledModelPK17M_WeightsRegistryP8M_Status)
* [`M_waitForModel()`](model.md#_CPPv414M_waitForModelP12M_AsyncModelP8M_Status)
* [`M_executeModelSync()`](model.md#_CPPv418M_executeModelSyncPK16M_RuntimeContextP12M_AsyncModelP16M_AsyncTensorMapP8M_Status)
* [`M_freeModel()`](model.md#_CPPv411M_freeModelP12M_AsyncModel)
* [`M_freeCompiledModel()`](model.md#_CPPv419M_freeCompiledModelP20M_AsyncCompiledModel)
* [`M_freeCompileConfig()`](model.md#_CPPv419M_freeCompileConfigP15M_CompileConfig)
* [Tensor](tensor.md)
* [`M_newTensorSpec()`](tensor.md#_CPPv415M_newTensorSpecPK7int64_t7int64_t7M_DtypePKcPK8M_Device)
* [`M_isDynamicRanked()`](tensor.md#_CPPv417M_isDynamicRankedPK12M_TensorSpec)
* [`M_getDimAt()`](tensor.md#_CPPv410M_getDimAtPK12M_TensorSpec6size_t)
* [`M_getRank()`](tensor.md#_CPPv49M_getRankPK12M_TensorSpec)
* [`M_getDtype()`](tensor.md#_CPPv410M_getDtypePK12M_TensorSpec)
* [`M_getName()`](tensor.md#_CPPv49M_getNameP12M_TensorSpec)
* [`M_newAsyncTensorMap()`](tensor.md#_CPPv419M_newAsyncTensorMapPK16M_RuntimeContext)
* [`M_borrowTensorInto()`](tensor.md#_CPPv418M_borrowTensorIntoP16M_AsyncTensorMapPvPK12M_TensorSpecP8M_Status)
* [`M_getTensorByNameFrom()`](tensor.md#_CPPv421M_getTensorByNameFromP16M_AsyncTensorMapPKcP8M_Status)
* [`M_getTensorNumElements()`](tensor.md#_CPPv422M_getTensorNumElementsPK13M_AsyncTensor)
* [`M_getTensorType()`](tensor.md#_CPPv415M_getTensorTypePK13M_AsyncTensor)
* [`M_getTensorData()`](tensor.md#_CPPv415M_getTensorDataPK13M_AsyncTensor)
* [`M_getTensorSpec()`](tensor.md#_CPPv415M_getTensorSpecPK13M_AsyncTensor)
* [`M_getDeviceTypeFromSpec()`](tensor.md#_CPPv423M_getDeviceTypeFromSpecPK12M_TensorSpec)
* [`M_getDeviceIdFromSpec()`](tensor.md#_CPPv421M_getDeviceIdFromSpecPK12M_TensorSpec)
* [`M_getTensorDevice()`](tensor.md#_CPPv417M_getTensorDevicePK13M_AsyncTensor)
* [`M_copyTensorToDevice()`](tensor.md#_CPPv420M_copyTensorToDeviceP13M_AsyncTensorP8M_DeviceP8M_Status)
* [`M_freeTensor()`](tensor.md#_CPPv412M_freeTensorP13M_AsyncTensor)
* [`M_freeTensorNameArray()`](tensor.md#_CPPv421M_freeTensorNameArrayP17M_TensorNameArray)
* [`M_freeTensorSpec()`](tensor.md#_CPPv416M_freeTensorSpecP12M_TensorSpec)
* [`M_freeAsyncTensorMap()`](tensor.md#_CPPv420M_freeAsyncTensorMapP16M_AsyncTensorMap)
* [Types](types.md)
* [`M_Status`](types.md#_CPPv48M_Status)
* [`M_RuntimeConfig`](types.md#_CPPv415M_RuntimeConfig)
* [`M_RuntimeContext`](types.md#_CPPv416M_RuntimeContext)
* [`M_CompileConfig`](types.md#_CPPv415M_CompileConfig)
* [`M_AsyncCompiledModel`](types.md#_CPPv420M_AsyncCompiledModel)
* [`M_AsyncModel`](types.md#_CPPv412M_AsyncModel)
* [`M_AsyncTensor`](types.md#_CPPv413M_AsyncTensor)
* [`M_TensorNameArray`](types.md#_CPPv417M_TensorNameArray)
* [`M_TensorSpec`](types.md#_CPPv412M_TensorSpec)
* [`M_AsyncTensorMap`](types.md#_CPPv416M_AsyncTensorMap)
* [`M_WeightsRegistry`](types.md#_CPPv417M_WeightsRegistry)
* [`M_Device`](types.md#_CPPv48M_Device)
* [`M_Dtype`](types.md#_CPPv47M_Dtype)
* [`M_UNKNOWN`](types.md#_CPPv4N7M_Dtype9M_UNKNOWNE)
* [`mIsInteger`](types.md#_CPPv4N7M_Dtype10mIsIntegerE)
* [`mIsFloat`](types.md#_CPPv4N7M_Dtype8mIsFloatE)
* [`mIsComplex`](types.md#_CPPv4N7M_Dtype10mIsComplexE)
* [`mIsSigned`](types.md#_CPPv4N7M_Dtype9mIsSignedE)
* [`kIntWidthShift`](types.md#_CPPv4N7M_Dtype14kIntWidthShiftE)
* [`M_INT1`](types.md#_CPPv4N7M_Dtype6M_INT1E)
* [`M_UINT1`](types.md#_CPPv4N7M_Dtype7M_UINT1E)
* [`M_INT2`](types.md#_CPPv4N7M_Dtype6M_INT2E)
* [`M_UINT2`](types.md#_CPPv4N7M_Dtype7M_UINT2E)
* [`M_INT4`](types.md#_CPPv4N7M_Dtype6M_INT4E)
* [`M_UINT4`](types.md#_CPPv4N7M_Dtype7M_UINT4E)
* [`M_INT8`](types.md#_CPPv4N7M_Dtype6M_INT8E)
* [`M_UINT8`](types.md#_CPPv4N7M_Dtype7M_UINT8E)
* [`M_INT16`](types.md#_CPPv4N7M_Dtype7M_INT16E)
* [`M_UINT16`](types.md#_CPPv4N7M_Dtype8M_UINT16E)
* [`M_INT32`](types.md#_CPPv4N7M_Dtype7M_INT32E)
* [`M_UINT32`](types.md#_CPPv4N7M_Dtype8M_UINT32E)
* [`M_INT64`](types.md#_CPPv4N7M_Dtype7M_INT64E)
* [`M_UINT64`](types.md#_CPPv4N7M_Dtype8M_UINT64E)
* [`M_INT128`](types.md#_CPPv4N7M_Dtype8M_INT128E)
* [`M_UINT128`](types.md#_CPPv4N7M_Dtype9M_UINT128E)
* [`M_FLOAT4_E2M1FN`](types.md#_CPPv4N7M_Dtype15M_FLOAT4_E2M1FNE)
* [`M_FLOAT8_E8M0FNU`](types.md#_CPPv4N7M_Dtype16M_FLOAT8_E8M0FNUE)
* [`M_FLOAT8_E3M4`](types.md#_CPPv4N7M_Dtype13M_FLOAT8_E3M4E)
* [`M_FLOAT8_E4M3FN`](types.md#_CPPv4N7M_Dtype15M_FLOAT8_E4M3FNE)
* [`M_FLOAT8_E4M3FNUZ`](types.md#_CPPv4N7M_Dtype17M_FLOAT8_E4M3FNUZE)
* [`M_FLOAT8_E5M2`](types.md#_CPPv4N7M_Dtype13M_FLOAT8_E5M2E)
* [`M_FLOAT8_E5M2FNUZ`](types.md#_CPPv4N7M_Dtype17M_FLOAT8_E5M2FNUZE)
* [`M_FLOAT16`](types.md#_CPPv4N7M_Dtype9M_FLOAT16E)
* [`M_BFLOAT16`](types.md#_CPPv4N7M_Dtype10M_BFLOAT16E)
* [`M_FLOAT32`](types.md#_CPPv4N7M_Dtype9M_FLOAT32E)
* [`M_FLOAT64`](types.md#_CPPv4N7M_Dtype9M_FLOAT64E)
* [`M_BOOL`](types.md#_CPPv4N7M_Dtype6M_BOOLE)
* [`M_AllocatorType`](types.md#_CPPv415M_AllocatorType)
* [`kSystem`](types.md#_CPPv4N15M_AllocatorType7kSystemE)
* [`kCaching`](types.md#_CPPv4N15M_AllocatorType8kCachingE)
* [`M_ValueType`](types.md#_CPPv411M_ValueType)
* [`M_STRING_VALUE`](types.md#_CPPv4N11M_ValueType14M_STRING_VALUEE)
* [`M_DOUBLE_VALUE`](types.md#_CPPv4N11M_ValueType14M_DOUBLE_VALUEE)
* [`M_LONG_VALUE`](types.md#_CPPv4N11M_ValueType12M_LONG_VALUEE)
* [`M_BOOL_VALUE`](types.md#_CPPv4N11M_ValueType12M_BOOL_VALUEE)
* [`M_TENSOR_VALUE`](types.md#_CPPv4N11M_ValueType14M_TENSOR_VALUEE)
* [`M_LIST_VALUE`](types.md#_CPPv4N11M_ValueType12M_LIST_VALUEE)
* [`M_TUPLE_VALUE`](types.md#_CPPv4N11M_ValueType13M_TUPLE_VALUEE)
* [`M_DICT_VALUE`](types.md#_CPPv4N11M_ValueType12M_DICT_VALUEE)
* [`M_NONE_VALUE`](types.md#_CPPv4N11M_ValueType12M_NONE_VALUEE)
* [`M_UNKNOWN_VALUE`](types.md#_CPPv4N11M_ValueType15M_UNKNOWN_VALUEE)
* [`M_MOJO_VALUE`](types.md#_CPPv4N11M_ValueType12M_MOJO_VALUEE)
* [`M_PYTHON_MOJO_VALUE`](types.md#_CPPv4N11M_ValueType19M_PYTHON_MOJO_VALUEE)
* [`M_DeviceType`](types.md#_CPPv412M_DeviceType)
* [`M_HOST`](types.md#_CPPv4N12M_DeviceType6M_HOSTE)
* [`M_ACCELERATOR`](types.md#_CPPv4N12M_DeviceType13M_ACCELERATORE)
* [`M_ResultOutputStyle`](types.md#_CPPv419M_ResultOutputStyle)
* [`M_COMPACT`](types.md#_CPPv4N19M_ResultOutputStyle9M_COMPACTE)
* [`M_FULL`](types.md#_CPPv4N19M_ResultOutputStyle6M_FULLE)
* [`M_BINARY`](types.md#_CPPv4N19M_ResultOutputStyle8M_BINARYE)
* [`M_BINARY_MAX_CHECKPOINT`](types.md#_CPPv4N19M_ResultOutputStyle23M_BINARY_MAX_CHECKPOINTE)
* [`M_NONE`](types.md#_CPPv4N19M_ResultOutputStyle6M_NONEE)
## Async API usage
Our C API allows you to compile and execute models asynchronously. In general,
using asynchronous APIs effectively can be difficult, but rewarding for
performance. To help, this section explains some important concepts and mental
models to keep in mind when using the API.
Our APIs are asynchronous unless stated otherwise, typically indicated by
`Sync` in the function name. For example, we provide both `M_executeModel` and
[`M_executeModelSync()`](model.md#_CPPv418M_executeModelSyncPK16M_RuntimeContextP12M_AsyncModelP16M_AsyncTensorMapP8M_Status).
### Types
Our API describes the underlying async-holding types with a “value or error”
concept. Conceptually, this means that the type is in one of three states:
* `Constructed` - the value is not yet there, but there is no error
* `Available` - the value is there and ready for use
* `Error` - the value is not there and there is an error
### Synchronization points
When using the async APIs, be mindful of the synchronization-point APIs listed
below. They are what let you discern between the `Constructed` and `Available`
states described above: after a synchronization point returns, the input is
never in the `Constructed` state; it always resolves to either `Available` or
`Error`.
* [`M_waitForCompilation()`](model.md#_CPPv420M_waitForCompilationP20M_AsyncCompiledModelP8M_Status)
* [`M_waitForModel()`](model.md#_CPPv414M_waitForModelP12M_AsyncModelP8M_Status)
* `M_waitForTensors`
### Errors
Errors surface immediately when using our synchronous APIs. Otherwise, in the
case of async APIs, errors will not surface until the next synchronization
point. You can query the error message by calling [`M_getError()`](common.md#_CPPv410M_getErrorPK8M_Status).
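Putting this together, an async compilation might check the status twice: once after the call itself (which catches invalid configuration) and once after the synchronization point (which surfaces compilation errors). This sketch assumes `context`, `modelPath`, `status`, and a `logError()` helper already exist, as in the examples elsewhere on this page:

```c
// Sketch only: errors from the async M_compileModel() call may not
// surface until the next synchronization point.
M_CompileConfig *compileConfig = M_newCompileConfig();
M_setModelPath(compileConfig, modelPath);
M_AsyncCompiledModel *compiledModel =
    M_compileModel(context, &compileConfig, status);
if (M_isError(status)) {           // invalid config surfaces immediately
  logError(M_getError(status));
  return EXIT_FAILURE;
}
M_waitForCompilation(compiledModel, status);  // synchronization point
if (M_isError(status)) {           // compilation errors surface here
  logError(M_getError(status));
  return EXIT_FAILURE;
}
```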
---
## Model
```c
#include "max/c/model.h"
```
**Functions:**
### `M_newCompileConfig()`
> [M\_CompileConfig](types.md#_CPPv415M_CompileConfig) \*M\_newCompileConfig()
Creates an object you can use to configure model compilation.
You need `M_CompileConfig` as an argument for several functions, including [`M_setModelPath()`](#model_8h_1a03244f05c8a6092a55d3abc124ad90b7) and [`M_compileModel()`](#model_8h_1a88afca26a64b945885e1e1a0d09b5750).
* **Returns:**
A pointer to a new compilation configuration. You are responsible for the memory associated with the pointer returned. You can deallocate the memory by calling [`M_freeCompileConfig()`](#model_8h_1abbf74b13adaf5bc8a0bb4d46c40688d9). This compilation configuration can only be used for a single compilation call. Any subsequent compilations must be passed a new `M_CompileConfig` (created by calling [`M_newCompileConfig()`](#model_8h_1a417e7a581c096ca26c36a1875163b665) again).
### `M_setModelPath()`
> void M\_setModelPath([M\_CompileConfig](types.md#_CPPv415M_CompileConfig) \*compileConfig, const char \*path)
Sets the path to a model.
You must call this before you call [`M_compileModel()`](#model_8h_1a88afca26a64b945885e1e1a0d09b5750). Otherwise, [`M_compileModel()`](#model_8h_1a88afca26a64b945885e1e1a0d09b5750) returns an error in `status`.
* **Parameters:**
* compileConfig – The compilation configuration for your model, from [`M_newCompileConfig()`](#model_8h_1a417e7a581c096ca26c36a1875163b665).
* path – The path to your model. The model does not need to exist on the filesystem at this point. This follows the same semantics and expectations as `std::filesystem::path`.
### `M_compileModel()`
> [M\_AsyncCompiledModel](types.md#_CPPv420M_AsyncCompiledModel) \*M\_compileModel(const [M\_RuntimeContext](types.md#_CPPv416M_RuntimeContext) \*context, [M\_CompileConfig](types.md#_CPPv415M_CompileConfig) \*\*compileConfig, [M\_Status](types.md#_CPPv48M_Status) \*status)
Compiles a model.
This immediately returns an `M_AsyncCompiledModel`, with compilation happening asynchronously. If you need to block to await compilation, you can then call [`M_waitForCompilation()`](#model_8h_1a8040a6488596f863c205d769d92ad013).
You must call [`M_setModelPath()`](#model_8h_1a03244f05c8a6092a55d3abc124ad90b7) before you call this. For example:
```c
M_CompileConfig *compileConfig = M_newCompileConfig();
M_setModelPath(compileConfig, modelPath);
M_AsyncCompiledModel *compiledModel =
M_compileModel(context, &compileConfig, status);
if (M_isError(status)) {
logError(M_getError(status));
return EXIT_FAILURE;
}
```
The `M_AsyncCompiledModel` returned here is not ready for inference yet. You need to then initialize the model with [`M_initModel()`](#model_8h_1a2dcb9570ae117602579182d8faed494a).
* **Parameters:**
* context – The runtime context, from [`M_newRuntimeContext()`](context.md#context_8h_1a46a6c670f73e1ce560f3c2cc1de93175).
* compileConfig – The address of the compilation configuration for your model, created with [`M_newCompileConfig()`](#model_8h_1a417e7a581c096ca26c36a1875163b665) and with the model set via [`M_setModelPath()`](#model_8h_1a03244f05c8a6092a55d3abc124ad90b7). Ownership of the configuration is handed over to the API.
* status – The status used to report errors in the case of failures during model compilation.
* **Returns:**
A pointer to an `M_AsyncCompiledModel`. You are responsible for the memory associated with the pointer returned. You can deallocate the memory by calling [`M_freeCompiledModel()`](#model_8h_1a5b6846eb4d47d445eb65c305b1c81b1c). If the config is invalid, it returns a `NULL` pointer. If the model compilation fails, the pointer is `NULL` and the `status` parameter contains an error message. `compileConfig` will be reset to `NULL` after this call irrespective of status and cannot be reused, and any subsequent calls must take a new `M_CompileConfig`.
### `M_waitForCompilation()`
> void M\_waitForCompilation([M\_AsyncCompiledModel](types.md#_CPPv420M_AsyncCompiledModel) \*compiledModel, [M\_Status](types.md#_CPPv48M_Status) \*status)
Blocks execution until the model is compiled.
This waits for the async compiled model to be complete after calling [`M_compileModel()`](#model_8h_1a88afca26a64b945885e1e1a0d09b5750). When this function returns, the model is resolved to either a compiled model or an error.
* **Parameters:**
* compiledModel – The model received from [`M_compileModel()`](#model_8h_1a88afca26a64b945885e1e1a0d09b5750).
* status – The status used to report errors in the case of failures.
### `M_compileModelSync()`
> [M\_AsyncCompiledModel](types.md#_CPPv420M_AsyncCompiledModel) \*M\_compileModelSync(const [M\_RuntimeContext](types.md#_CPPv416M_RuntimeContext) \*context, [M\_CompileConfig](types.md#_CPPv415M_CompileConfig) \*\*compileConfig, [M\_Status](types.md#_CPPv48M_Status) \*status)
Synchronously compiles a model.
Unlike [`M_compileModel()`](#model_8h_1a88afca26a64b945885e1e1a0d09b5750), this blocks until model compilation is complete. It returns an `M_AsyncCompiledModel` without needing to call [`M_waitForCompilation()`](#model_8h_1a8040a6488596f863c205d769d92ad013). All other setup and usage is identical to [`M_compileModel()`](#model_8h_1a88afca26a64b945885e1e1a0d09b5750).
* **Parameters:**
* context – The runtime context, from [`M_newRuntimeContext()`](context.md#context_8h_1a46a6c670f73e1ce560f3c2cc1de93175).
* compileConfig – The address of the compilation configuration for your model, created with [`M_newCompileConfig()`](#model_8h_1a417e7a581c096ca26c36a1875163b665) and with the model set via [`M_setModelPath()`](#model_8h_1a03244f05c8a6092a55d3abc124ad90b7). Ownership of the configuration is handed over to the API.
* status – The status used to report errors in the case of failures during model compilation.
* **Returns:**
A pointer to an `M_AsyncCompiledModel`. You are responsible for the memory associated with the pointer returned. You can deallocate the memory by calling [`M_freeCompiledModel()`](#model_8h_1a5b6846eb4d47d445eb65c305b1c81b1c). If the config is invalid, it returns a `NULL` pointer. If the model compilation fails, the pointer is `NULL` and the `status` parameter contains an error message. `compileConfig` will be reset to `NULL` after this call irrespective of status and cannot be reused, and any subsequent calls must take a new `M_CompileConfig`.
### `M_initModel()`
> [M\_AsyncModel](types.md#_CPPv412M_AsyncModel) \*M\_initModel(const [M\_RuntimeContext](types.md#_CPPv416M_RuntimeContext) \*context, const [M\_AsyncCompiledModel](types.md#_CPPv420M_AsyncCompiledModel) \*compiledModel, const [M\_WeightsRegistry](types.md#_CPPv417M_WeightsRegistry) \*weightsRegistry, [M\_Status](types.md#_CPPv48M_Status) \*status)
Sets up a model for execution.
You can call this immediately after [`M_compileModel()`](#model_8h_1a88afca26a64b945885e1e1a0d09b5750)—you don’t need to wait for the async compilation.
This function also returns immediately with model initialization happening asynchronously. For example:
```c
M_AsyncModel *model = M_initModel(
context, compiledModel, weightsRegistry, status);
if (M_isError(status)) {
logError(M_getError(status));
return EXIT_FAILURE;
}
```
If you want to block until `M_AsyncModel` is initialized, you can call [`M_waitForModel()`](#model_8h_1a852bec3f80cebb5c06911091d5cab349), but that’s not necessary and you can immediately call [`M_executeModelSync()`](#model_8h_1a2ced4683834a77d0b943a6bc72d846d5).
* **Parameters:**
* context – The runtime context, from [`M_newRuntimeContext()`](context.md#context_8h_1a46a6c670f73e1ce560f3c2cc1de93175).
* compiledModel – The compiled model, from [`M_compileModel()`](#model_8h_1a88afca26a64b945885e1e1a0d09b5750).
* weightsRegistry – A mapping from weights’ names to their data. The weights registry is used to update weights or otherwise pass weights to the model init block at runtime, without recompiling the model graph. If the model doesn’t use the weights registry, it is safe to pass `NULL`.
* status – The status used to report errors in the case of failures. The status contains an error only if the given context or compiled model is invalid. Other errors will not surface until the next synchronization point.
* **Returns:**
A pointer to an `M_AsyncModel` that holds an async value to a compiled model. You are responsible for the memory associated with the pointer returned. You can deallocate the memory by calling [`M_freeModel()`](#model_8h_1a4094fa8e414f8b6a6563474f8840d33c). If model initialization fails, the `status` parameter contains an error message.
### `M_waitForModel()`
> void M\_waitForModel([M\_AsyncModel](types.md#_CPPv412M_AsyncModel) \*model, [M\_Status](types.md#_CPPv48M_Status) \*status)
Blocks execution until the model is initialized.
This waits for the model setup to finish in [`M_initModel()`](#model_8h_1a2dcb9570ae117602579182d8faed494a).
* **Parameters:**
* model – The model.
* status – The status used to report errors in the case of failures.
### `M_executeModelSync()`
> [M\_AsyncTensorMap](types.md#_CPPv416M_AsyncTensorMap) \*M\_executeModelSync(const [M\_RuntimeContext](types.md#_CPPv416M_RuntimeContext) \*context, [M\_AsyncModel](types.md#_CPPv412M_AsyncModel) \*initializedModel, [M\_AsyncTensorMap](types.md#_CPPv416M_AsyncTensorMap) \*inputs, [M\_Status](types.md#_CPPv48M_Status) \*status)
Executes a model synchronously.
The inputs and outputs are `M_AsyncTensorMap` objects to allow chaining of inference. This operation is blocking and waits until the output results are ready.
* **Parameters:**
* context – The runtime context.
* initializedModel – The model to execute, from [`M_initModel()`](#model_8h_1a2dcb9570ae117602579182d8faed494a). Although that function is async, you can pass the `M_AsyncModel` here immediately.
* inputs – The tensor inputs.
* status – The status used to report errors in the case of failures. This includes failures encountered while running the model; there is no need for an explicit synchronization point.
* **Returns:**
A pointer to an `M_AsyncTensorMap` that holds the output tensors. These tensors are in a resolved state. You are responsible for the memory associated with the pointer returned. You can deallocate the memory by calling [`M_freeAsyncTensorMap()`](tensor.md#tensor_8h_1a0ac9628dcba39c9977b7f7ff95d8781e). In the case that executing the model fails, the `status` parameter contains an error message.
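As a sketch, executing a model and reading back an output might look like the following. The output tensor name `"output0"` is hypothetical; actual names depend on your model. This assumes `context`, `model`, `inputs`, `status`, and a `logError()` helper already exist, as in the examples elsewhere on this page:

```c
// Sketch only: "output0" is a hypothetical output tensor name.
M_AsyncTensorMap *outputs =
    M_executeModelSync(context, model, inputs, status);
if (M_isError(status)) {
  logError(M_getError(status));
  return EXIT_FAILURE;
}
M_AsyncTensor *result =
    M_getTensorByNameFrom(outputs, "output0", status);
if (M_isError(status)) {
  logError(M_getError(status));
  return EXIT_FAILURE;
}
const void *data = M_getTensorData(result);
size_t numElements = M_getTensorNumElements(result);
// ... use the data, then release everything you own:
M_freeTensor(result);
M_freeAsyncTensorMap(outputs);
```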
### `M_freeModel()`
> void M\_freeModel([M\_AsyncModel](types.md#_CPPv412M_AsyncModel) \*model)
Deallocates the memory for the model. No-op if `model` is `NULL`.
* **Parameters:**
model – The model to deallocate.
### `M_freeCompiledModel()`
> void M\_freeCompiledModel([M\_AsyncCompiledModel](types.md#_CPPv420M_AsyncCompiledModel) \*model)
Deallocates the memory for the compiled model. No-op if `model` is `NULL`.
* **Parameters:**
model – The compiled model to deallocate.
### `M_freeCompileConfig()`
> void M\_freeCompileConfig([M\_CompileConfig](types.md#_CPPv415M_CompileConfig) \*config)
Deallocates the memory for the compile config. No-op if `config` is `NULL`.
* **Parameters:**
config – The compilation configuration to deallocate.
---
## Tensor
```c
#include "max/c/tensor.h"
```
**Functions:**
### `M_newTensorSpec()`
> [M\_TensorSpec](types.md#_CPPv412M_TensorSpec) \*M\_newTensorSpec(const int64\_t \*shape, int64\_t rankSize, [M\_Dtype](types.md#_CPPv47M_Dtype) dtype, const char \*tensorName, const [M\_Device](types.md#_CPPv48M_Device) \*device)
Creates a tensor specification.
You need this in order to set the input tensors with [`M_borrowTensorInto()`](#tensor_8h_1a58a1646cfa1726b047b020c89eb7345c).
When storing tensor data in memory, we always use a diminishing stride size. That is, earlier dimensions in the shape have larger strides than later dimensions. For example, a C array declared as `int arr[1][2][3]` would have a shape specified as `{1, 2, 3}`.
* **Parameters:**
* shape – The shape of the tensor.
* rankSize – The rank size of the tensor.
* dtype – The datatype for the tensor.
* tensorName – The name for the tensor. This string gets copied as part of the operation of `M_newTensorSpec`, so your original string need not remain valid after the completion of this call.
* device – The device on which the tensor resides.
* **Returns:**
A pointer to the tensor spec. You are responsible for the memory associated with the pointer returned. The memory can be deallocated by calling [`M_freeTensorSpec()`](#tensor_8h_1af0b957daeba1760134c3f24079b53026).
### `M_isDynamicRanked()`
> int M\_isDynamicRanked(const [M\_TensorSpec](types.md#_CPPv412M_TensorSpec) \*spec)
Returns whether the given spec has a dynamic rank.
* **Parameters:**
spec – The tensor spec.
* **Returns:**
`1` if the rank is dynamic. `0` otherwise.
### `M_getDimAt()`
> int64\_t M\_getDimAt(const [M\_TensorSpec](types.md#_CPPv412M_TensorSpec) \*spec, size\_t axis)
Gets the element at a particular axis.
* **Parameters:**
* spec – The tensor spec.
* axis – The requested axis.
* **Returns:**
The dimension at the requested axis, if the spec and axis are valid and the rank is static. Otherwise, `0`. A dimension equal to `kDynamicDimensionValue` indicates a dynamic dimension (for example, the batch size of a model that expects a batched tensor).
### `M_getRank()`
> int64\_t M\_getRank(const [M\_TensorSpec](types.md#_CPPv412M_TensorSpec) \*spec)
Gets the rank from the tensor spec.
* **Parameters:**
spec – The tensor spec.
* **Returns:**
The number of dimensions in the tensor spec, if the spec is valid and has a static rank; `kDynamicRankValue` if the rank is dynamic. Otherwise, `0`.
### `M_getDtype()`
> [M\_Dtype](types.md#_CPPv47M_Dtype) M\_getDtype(const [M\_TensorSpec](types.md#_CPPv412M_TensorSpec) \*spec)
Gets the datatype from the tensor spec.
* **Parameters:**
spec – The tensor spec.
* **Returns:**
The element type from the tensor spec if the tensor spec is valid. Otherwise, `M_UNKNOWN`.
### `M_getName()`
> const char \*M\_getName([M\_TensorSpec](types.md#_CPPv412M_TensorSpec) \*spec)
Gets the name of the tensor from the tensor spec.
* **Parameters:**
spec – The tensor spec.
* **Returns:**
A null-terminated string containing the name of the tensor if the `spec` is valid. Otherwise, `NULL`. The memory associated with the returned string is owned by `spec`.
### `M_newAsyncTensorMap()`
> [M\_AsyncTensorMap](types.md#_CPPv416M_AsyncTensorMap) \*M\_newAsyncTensorMap(const [M\_RuntimeContext](types.md#_CPPv416M_RuntimeContext) \*context)
Creates a map of tensor names to async tensors.
* **Parameters:**
context – The runtime context.
* **Returns:**
A pointer to the tensor map. You are responsible for the memory associated with the pointer returned. The memory can be deallocated by calling [`M_freeAsyncTensorMap()`](#tensor_8h_1a0ac9628dcba39c9977b7f7ff95d8781e).
### `M_borrowTensorInto()`
> void M\_borrowTensorInto([M\_AsyncTensorMap](types.md#_CPPv416M_AsyncTensorMap) \*tensors, void \*input, const [M\_TensorSpec](types.md#_CPPv412M_TensorSpec) \*tensorSpec, [M\_Status](types.md#_CPPv48M_Status) \*status)
Adds a tensor to the tensor map.
You are responsible for the lifetime of the input tensor data. Its data gets “borrowed” into the tensor map.
* **Parameters:**
* tensors – The tensor map, from [`M_newAsyncTensorMap()`](#tensor_8h_1a18039c6e6c1769b947120b27178306eb).
* input – The input tensor data.
* tensorSpec – The tensor spec, from [`M_newTensorSpec()`](#tensor_8h_1ab7546d4d0a22ae82134d200272e8f8f4). This gets copied as part of the operation of `M_borrowTensorInto`, so your original tensorSpec need not exist through the lifetime of the tensor map.
* status – The status object for reporting errors.
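As a sketch, building an input tensor map might look like the following. The shape, dtype, and tensor name `"input0"` are hypothetical; use your model's actual input signature. This assumes `context`, `device`, `inputData`, and `status` already exist:

```c
// Sketch only: the shape, dtype, and name "input0" are hypothetical.
int64_t shape[] = {1, 3, 224, 224};
M_TensorSpec *spec =
    M_newTensorSpec(shape, 4, M_FLOAT32, "input0", device);
M_AsyncTensorMap *inputs = M_newAsyncTensorMap(context);
M_borrowTensorInto(inputs, inputData, spec, status);
M_freeTensorSpec(spec);  // the spec is copied, so it's safe to free now
```

Note that the tensor data itself is only borrowed, so `inputData` must remain valid for as long as the tensor map uses it.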
### `M_getTensorByNameFrom()`
> [M\_AsyncTensor](types.md#_CPPv413M_AsyncTensor) \*M\_getTensorByNameFrom([M\_AsyncTensorMap](types.md#_CPPv416M_AsyncTensorMap) \*tensorMap, const char \*name, [M\_Status](types.md#_CPPv48M_Status) \*status)
Gets a tensor from the tensor map by name.
* **Parameters:**
* tensorMap – The tensor map.
* name – The name of the tensor.
* status – The status object for reporting errors.
* **Returns:**
A pointer to the tensor. You are responsible for the memory associated with the pointer returned. The memory can be deallocated by calling [`M_freeTensor()`](#tensor_8h_1a339008df4a10af5e8c01ae970598765c). The held tensor inside the return value is simply borrowed from the corresponding input `M_AsyncTensorMap`. If the tensor map or name are invalid, a `NULL` pointer is returned and the `status` parameter contains an error message.
### `M_getTensorNumElements()`
> size\_t M\_getTensorNumElements(const [M\_AsyncTensor](types.md#_CPPv413M_AsyncTensor) \*tensor)
Gets the number of elements for the tensor.
* **Parameters:**
tensor – The tensor which must not be `NULL`.
* **Returns:**
The number of elements for the given tensor.
### `M_getTensorType()`
> [M\_Dtype](types.md#_CPPv47M_Dtype) M\_getTensorType(const [M\_AsyncTensor](types.md#_CPPv413M_AsyncTensor) \*tensor)
Gets the corresponding `M_Dtype` for the tensor.
* **Parameters:**
tensor – The tensor which must not be `NULL`.
* **Returns:**
The corresponding `M_Dtype` for the tensor.
### `M_getTensorData()`
> const void \*M\_getTensorData(const [M\_AsyncTensor](types.md#_CPPv413M_AsyncTensor) \*tensor)
Gets a pointer to the underlying data of the tensor.
* **Parameters:**
tensor – The tensor which must not be `NULL`.
* **Returns:**
A pointer to the underlying data of the tensor. This pointer is valid for the lifetime of the underlying tensor.
### `M_getTensorSpec()`
> [M\_TensorSpec](types.md#_CPPv412M_TensorSpec) \*M\_getTensorSpec(const [M\_AsyncTensor](types.md#_CPPv413M_AsyncTensor) \*tensor)
Gets the tensor spec for the tensor.
* **Parameters:**
tensor – The tensor.
* **Returns:**
The tensor spec for the tensor if the tensor is valid. Otherwise, `NULL`.
### `M_getDeviceTypeFromSpec()`
> [M\_DeviceType](types.md#_CPPv412M_DeviceType) M\_getDeviceTypeFromSpec(const [M\_TensorSpec](types.md#_CPPv412M_TensorSpec) \*spec)
Gets the device type from a tensor specification.
* **Parameters:**
spec – The tensor spec.
* **Returns:**
The device type (CPU or GPU).
### `M_getDeviceIdFromSpec()`
> int M\_getDeviceIdFromSpec(const [M\_TensorSpec](types.md#_CPPv412M_TensorSpec) \*spec)
Gets the device ID from a tensor specification.
* **Parameters:**
spec – The tensor spec.
* **Returns:**
The device ID. Returns `0` if the spec is invalid.
### `M_getTensorDevice()`
> [M\_Device](types.md#_CPPv48M_Device) \*M\_getTensorDevice(const [M\_AsyncTensor](types.md#_CPPv413M_AsyncTensor) \*tensor)
Gets the device on which a tensor resides.
* **Parameters:**
tensor – The tensor.
* **Returns:**
The device on which the tensor resides, or `NULL` if the tensor is invalid. The caller owns the returned device and must free it with `M_freeDevice()`.
### `M_copyTensorToDevice()`
> [M\_AsyncTensor](types.md#_CPPv413M_AsyncTensor) \*M\_copyTensorToDevice([M\_AsyncTensor](types.md#_CPPv413M_AsyncTensor) \*tensor, [M\_Device](types.md#_CPPv48M_Device) \*device, [M\_Status](types.md#_CPPv48M_Status) \*status)
Copies a tensor to a different device.
Creates a copy of the tensor on the specified device.
* **Parameters:**
* tensor – The tensor to copy.
* device – The target device.
* status – The status object for reporting errors.
* **Returns:**
A pointer to the tensor on the target device. The caller owns the returned memory and must deallocate it by calling [`M_freeTensor()`](#tensor_8h_1a339008df4a10af5e8c01ae970598765c). Returns `NULL` if the operation fails, with an error message in the status.
### `M_freeTensor()`
> void M\_freeTensor([M\_AsyncTensor](types.md#_CPPv413M_AsyncTensor) \*tensor)
Deallocates the memory for the tensor. No-op if `tensor` is `NULL`.
* **Parameters:**
tensor – The tensor to deallocate.
### `M_freeTensorNameArray()`
> void M\_freeTensorNameArray([M\_TensorNameArray](types.md#_CPPv417M_TensorNameArray) \*names)
Deallocates the memory for the array of tensor names. No-op if `names` is `NULL`.
* **Parameters:**
names – The tensor names to deallocate.
### `M_freeTensorSpec()`
> void M\_freeTensorSpec([M\_TensorSpec](types.md#_CPPv412M_TensorSpec) \*spec)
Deallocates the memory for the tensor spec. No-op if `spec` is `NULL`.
* **Parameters:**
spec – The tensor spec to deallocate.
### `M_freeAsyncTensorMap()`
> void M\_freeAsyncTensorMap([M\_AsyncTensorMap](types.md#_CPPv416M_AsyncTensorMap) \*tensorMap)
Deallocates the memory for the tensor map. No-op if `tensorMap` is `NULL`.
* **Parameters:**
tensorMap – The tensor map to deallocate.
---
## Types
```c
#include "max/c/types.h"
```
**Typedefs:**
### `M_Status`
> typedef struct [M\_Status](#_CPPv48M_Status) M\_Status
Contains the success or failure of an API call.
In general, any API that may fail accepts a `M_Status` argument that is filled in with a meaningful error message on failure.
You can create this with [`M_newStatus()`](common.md#common_8h_1adb1ef3fc2e0bcdc8eb17cac3ce91835b). When you’re done, call [`M_freeStatus()`](common.md#common_8h_1ab5067fd51a5696b3679f7f629d3329c4).
### `M_RuntimeConfig`
> typedef struct [M\_RuntimeConfig](#_CPPv415M_RuntimeConfig) M\_RuntimeConfig
Specifies the MAX Engine configuration.
Configuration properties include the number of threads, artifact path, etc.
You can create this with [`M_newRuntimeConfig()`](context.md#context_8h_1a963f1d4eefd812ba8691acf516007cfc). When you’re done, call [`M_freeRuntimeConfig()`](context.md#context_8h_1a47f7e22f7f71da9ab5fb3a1886911610).
### `M_RuntimeContext`
> typedef struct [M\_RuntimeContext](#_CPPv416M_RuntimeContext) M\_RuntimeContext
Contains information that needs to be shared between APIs.
You can create this with [`M_newRuntimeContext()`](context.md#context_8h_1a46a6c670f73e1ce560f3c2cc1de93175). When you’re done, call [`M_freeRuntimeContext()`](context.md#context_8h_1a2434a11d8d65890c66f6b5516243a730).
### `M_CompileConfig`
> typedef struct [M\_CompileConfig](#_CPPv415M_CompileConfig) M\_CompileConfig
Specifies the configuration required for model compilation.
You can create this with [`M_newCompileConfig()`](model.md#model_8h_1a417e7a581c096ca26c36a1875163b665). When you’re done, call [`M_freeCompileConfig()`](model.md#model_8h_1abbf74b13adaf5bc8a0bb4d46c40688d9).
### `M_AsyncCompiledModel`
> typedef struct [M\_AsyncCompiledModel](#_CPPv420M_AsyncCompiledModel) M\_AsyncCompiledModel
Contains an async value to a compiled model.
`M_AsyncCompiledModel` can be passed to other APIs that accept compiled models as a function parameter. This async value will eventually resolve to a compiled model or an error in the case of compilation failure.
You can create this with [`M_compileModel()`](model.md#model_8h_1a88afca26a64b945885e1e1a0d09b5750). When you’re done, call [`M_freeCompiledModel()`](model.md#model_8h_1a5b6846eb4d47d445eb65c305b1c81b1c).
### `M_AsyncModel`
> typedef struct [M\_AsyncModel](#_CPPv412M_AsyncModel) M\_AsyncModel
Contains a future used for inference.
The future will resolve to a model that’s ready for inference.
You can create this with [`M_initModel()`](model.md#model_8h_1a2dcb9570ae117602579182d8faed494a). When you’re done, call [`M_freeModel()`](model.md#model_8h_1a4094fa8e414f8b6a6563474f8840d33c).
### `M_AsyncTensor`
> typedef struct [M\_AsyncTensor](#_CPPv413M_AsyncTensor) M\_AsyncTensor
Contains an async value to a tensor for inference.
You can get this from [`M_getTensorByNameFrom()`](tensor.md#tensor_8h_1a9522ad955454dbd2d044066dea2cad95). When you’re done, call [`M_freeTensor()`](tensor.md#tensor_8h_1a339008df4a10af5e8c01ae970598765c).
### `M_TensorNameArray`
> typedef struct [M\_TensorNameArray](#_CPPv417M_TensorNameArray) M\_TensorNameArray
Contains an array of tensor names of model inputs or outputs.
You can get this from `M_getInputNames()` and `M_getOutputNames()`. When you’re done, call [`M_freeTensorNameArray()`](tensor.md#tensor_8h_1a7fa5d2aff7f89143ae1905fc29b5b112).
### `M_TensorSpec`
> typedef struct [M\_TensorSpec](#_CPPv412M_TensorSpec) M\_TensorSpec
Contains the representation of a shape and an element type.
You can create this with [`M_newTensorSpec()`](tensor.md#tensor_8h_1ab7546d4d0a22ae82134d200272e8f8f4). When you’re done, call [`M_freeTensorSpec()`](tensor.md#tensor_8h_1af0b957daeba1760134c3f24079b53026).
### `M_AsyncTensorMap`
> typedef struct [M\_AsyncTensorMap](#_CPPv416M_AsyncTensorMap) M\_AsyncTensorMap
Contains a collection of tensors.
The collection of tensors is used to represent inputs and outputs when executing a model.
You can create this with [`M_newAsyncTensorMap()`](tensor.md#tensor_8h_1a18039c6e6c1769b947120b27178306eb). When you’re done, call [`M_freeAsyncTensorMap()`](tensor.md#tensor_8h_1a0ac9628dcba39c9977b7f7ff95d8781e).
### `M_WeightsRegistry`
> typedef struct [M\_WeightsRegistry](#_CPPv417M_WeightsRegistry) M\_WeightsRegistry
Maps unique weight names to their backing data.
### `M_Device`
> typedef struct [M\_Device](#_CPPv48M_Device) M\_Device
Contains a device handle.
A device represents a computational unit (CPU or GPU) that can execute operations and hold tensors.
You can create this with `M_newDevice()`. When you’re done, call `M_freeDevice()`.
**Enums:**
### `M_Dtype`
> enum M\_Dtype
Represents all data types supported by the framework.
Values:
#### `M_UNKNOWN`
> enumerator M\_UNKNOWN
#### `mIsInteger`
> enumerator mIsInteger
#### `mIsFloat`
> enumerator mIsFloat
#### `mIsComplex`
> enumerator mIsComplex
#### `mIsSigned`
> enumerator mIsSigned
Bit 0 encodes “isSigned”.
#### `kIntWidthShift`
> enumerator kIntWidthShift
#### `M_INT1`
> enumerator M\_INT1
#### `M_UINT1`
> enumerator M\_UINT1
#### `M_INT2`
> enumerator M\_INT2
#### `M_UINT2`
> enumerator M\_UINT2
#### `M_INT4`
> enumerator M\_INT4
#### `M_UINT4`
> enumerator M\_UINT4
#### `M_INT8`
> enumerator M\_INT8
#### `M_UINT8`
> enumerator M\_UINT8
#### `M_INT16`
> enumerator M\_INT16
#### `M_UINT16`
> enumerator M\_UINT16
#### `M_INT32`
> enumerator M\_INT32
#### `M_UINT32`
> enumerator M\_UINT32
#### `M_INT64`
> enumerator M\_INT64
#### `M_UINT64`
> enumerator M\_UINT64
#### `M_INT128`
> enumerator M\_INT128
#### `M_UINT128`
> enumerator M\_UINT128
#### `M_FLOAT4_E2M1FN`
> enumerator M\_FLOAT4\_E2M1FN
Bits 0 through 3 indicate the kind of FP value.
#### `M_FLOAT8_E8M0FNU`
> enumerator M\_FLOAT8\_E8M0FNU
Some slots are left blank here so that more low-precision types can be supported in the future.
#### `M_FLOAT8_E3M4`
> enumerator M\_FLOAT8\_E3M4
#### `M_FLOAT8_E4M3FN`
> enumerator M\_FLOAT8\_E4M3FN
#### `M_FLOAT8_E4M3FNUZ`
> enumerator M\_FLOAT8\_E4M3FNUZ
#### `M_FLOAT8_E5M2`
> enumerator M\_FLOAT8\_E5M2
#### `M_FLOAT8_E5M2FNUZ`
> enumerator M\_FLOAT8\_E5M2FNUZ
#### `M_FLOAT16`
> enumerator M\_FLOAT16
#### `M_BFLOAT16`
> enumerator M\_BFLOAT16
#### `M_FLOAT32`
> enumerator M\_FLOAT32
#### `M_FLOAT64`
> enumerator M\_FLOAT64
#### `M_BOOL`
> enumerator M\_BOOL
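The helper enumerators above (`mIsSigned`, `kIntWidthShift`, and the bit-range notes) suggest that integer dtype tags pack signedness and width into bit fields. The Python sketch below is purely illustrative of that technique; the shift value and bit layout are assumptions, not the actual MAX encoding:

```python
# Illustrative bit-packed dtype tags, NOT the actual MAX layout.
# Assumption: bit 0 encodes signedness; log2(width) sits in higher bits.
M_IS_SIGNED = 1 << 0       # bit 0 encodes "isSigned"
K_INT_WIDTH_SHIFT = 1      # hypothetical shift for the width field

def make_int_dtype(width_log2: int, signed: bool) -> int:
    """Pack an integer dtype tag from log2(bit width) and signedness."""
    return (width_log2 << K_INT_WIDTH_SHIFT) | (M_IS_SIGNED if signed else 0)

def is_signed(tag: int) -> bool:
    return bool(tag & M_IS_SIGNED)

def width_bits(tag: int) -> int:
    return 1 << (tag >> K_INT_WIDTH_SHIFT)

int32_tag = make_int_dtype(5, signed=True)  # 2**5 = 32-bit signed integer
assert is_signed(int32_tag) and width_bits(int32_tag) == 32
```

Packing metadata into the enum value this way lets predicates such as signedness checks run as a single bitwise test.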
### `M_AllocatorType`
> enum M\_AllocatorType
Contains an `AllocatorType`. You can choose between `kCaching` and `kSystem`: `kCaching` trades higher memory usage for better performance, while `kSystem` uses the default system allocator.
Values:
#### `kSystem`
> enumerator kSystem
#### `kCaching`
> enumerator kCaching
### `M_ValueType`
> enum M\_ValueType
Represents the type of a value.
Values:
#### `M_STRING_VALUE`
> enumerator M\_STRING\_VALUE
#### `M_DOUBLE_VALUE`
> enumerator M\_DOUBLE\_VALUE
#### `M_LONG_VALUE`
> enumerator M\_LONG\_VALUE
#### `M_BOOL_VALUE`
> enumerator M\_BOOL\_VALUE
#### `M_TENSOR_VALUE`
> enumerator M\_TENSOR\_VALUE
#### `M_LIST_VALUE`
> enumerator M\_LIST\_VALUE
#### `M_TUPLE_VALUE`
> enumerator M\_TUPLE\_VALUE
#### `M_DICT_VALUE`
> enumerator M\_DICT\_VALUE
#### `M_NONE_VALUE`
> enumerator M\_NONE\_VALUE
#### `M_UNKNOWN_VALUE`
> enumerator M\_UNKNOWN\_VALUE
#### `M_MOJO_VALUE`
> enumerator M\_MOJO\_VALUE
#### `M_PYTHON_MOJO_VALUE`
> enumerator M\_PYTHON\_MOJO\_VALUE
### `M_DeviceType`
> enum M\_DeviceType
Represents the type of device.
Values:
#### `M_HOST`
> enumerator M\_HOST
#### `M_ACCELERATOR`
> enumerator M\_ACCELERATOR
### `M_ResultOutputStyle`
> enum M\_ResultOutputStyle
Represents the result output style for debug printing.
Values:
#### `M_COMPACT`
> enumerator M\_COMPACT
#### `M_FULL`
> enumerator M\_FULL
#### `M_BINARY`
> enumerator M\_BINARY
#### `M_BINARY_MAX_CHECKPOINT`
> enumerator M\_BINARY\_MAX\_CHECKPOINT
#### `M_NONE`
> enumerator M\_NONE
---
## API references
import ListingCards from '@site/src/components/Listing/ListingCards';
export const cards = [
{
title: 'Python',
url: '/max/api/python',
description: 'The Python library API reference.'
},
{
title: 'Mojo',
url: '/mojo/lib',
description: 'The Mojo library API reference.'
},
{
title: 'REST',
url: '/max/api/serve',
description: 'The MAX serving REST API reference.'
}
]
---
## BackgroundRecorder
## `BackgroundRecorder` {#max.diagnostics.gpu.BackgroundRecorder}
> class max.diagnostics.gpu.BackgroundRecorder
Asynchronous GPU metrics collection and data export capabilities.
The `BackgroundRecorder` enables continuous monitoring of GPU performance metrics
without blocking the main application thread. It automatically samples GPU
statistics at one-second intervals in a separate process, making it ideal for
profiling long-running inference sessions or training workloads.
When used as a context manager, the recorder starts background collection upon
entry and stops collection upon exit. The collected statistics are then
available through the stats property as a time-series of GPU measurements.
```python
from max.diagnostics.gpu import BackgroundRecorder
with BackgroundRecorder() as recorder:
    # Run your GPU workload here
    run_inference_session()

for i, snapshot in enumerate(recorder.stats):
    print(f"Sample {i}: {len(snapshot)} GPUs detected")
    for gpu_id, gpu_stats in snapshot.items():
        print(f"  {gpu_id}: {gpu_stats.memory.used_bytes} bytes used")
```
### `stats` {#max.diagnostics.gpu.BackgroundRecorder.stats}
> property stats: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [GPUStats](GPUStats.md#max.diagnostics.gpu.GPUStats)]]
Time-series of GPU statistics collected during background recording.
**Returns:**
A list of dictionaries, where each dictionary represents GPU statistics
at a specific point in time. Each dictionary maps GPU identifiers to
their corresponding [`GPUStats`](GPUStats.md#max.diagnostics.gpu.GPUStats) objects.
**Raises:**
[RuntimeError](https://docs.python.org/3/library/exceptions.html#RuntimeError) – If accessed before the recorder context has exited.
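Since `stats` is a plain list of dictionaries, ordinary Python is enough to post-process a recording. A hedged sketch that reduces a time-series to the peak memory usage per GPU (`peak_used_bytes` is a hypothetical helper, not part of the API; the stand-in objects below only mimic the `GPUStats`/`MemoryStats` attribute shapes):

```python
from types import SimpleNamespace as NS

def peak_used_bytes(stats):
    """Reduce [{gpu_id: stats}, ...] snapshots to peak used bytes per GPU."""
    peaks = {}
    for snapshot in stats:
        for gpu_id, gpu_stats in snapshot.items():
            used = gpu_stats.memory.used_bytes
            peaks[gpu_id] = max(used, peaks.get(gpu_id, 0))
    return peaks

# Stand-in snapshots (no GPU required) mirroring the documented shapes.
series = [
    {"nv0": NS(memory=NS(used_bytes=100))},
    {"nv0": NS(memory=NS(used_bytes=300))},
    {"nv0": NS(memory=NS(used_bytes=200))},
]
assert peak_used_bytes(series) == {"nv0": 300}
```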
---
## GPUDiagContext
## `GPUDiagContext` {#max.diagnostics.gpu.GPUDiagContext}
> class max.diagnostics.gpu.GPUDiagContext
Context manager providing unified access to GPU diagnostic information across NVIDIA and AMD hardware.
This class automatically detects and initializes supported GPU vendor libraries
(NVML for NVIDIA, ROCm SMI for AMD) and provides a unified interface for
collecting diagnostic statistics from all available GPUs in the system.
```python
from max.diagnostics.gpu import GPUDiagContext
with GPUDiagContext() as ctx:
    stats = ctx.get_stats()
    for gpu_id, gpu_stats in stats.items():
        print(f"GPU {gpu_id}: {gpu_stats.memory.used_bytes} bytes used")
```
### `get_stats()` {#max.diagnostics.gpu.GPUDiagContext.get_stats}
> get\_stats()
Retrieve current GPU statistics for all detected GPUs in the system.
**Returns:**
A dictionary mapping GPU identifiers to their current statistics.
NVIDIA GPUs are prefixed with `nv` (e.g., `nv0`, `nv1`) and AMD
GPUs are prefixed with `amd` (e.g., `amd0`, `amd1`).
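Because the vendor is encoded in the identifier prefix, results can be grouped with plain string handling. A minimal sketch assuming only the prefix convention documented above (`vendor_of` is an illustrative helper, not part of the API):

```python
def vendor_of(gpu_id: str) -> str:
    """Map an identifier like 'nv0' or 'amd1' to its vendor name."""
    if gpu_id.startswith("nv"):
        return "nvidia"
    if gpu_id.startswith("amd"):
        return "amd"
    raise ValueError(f"unrecognized GPU identifier: {gpu_id}")

assert vendor_of("nv0") == "nvidia"
assert vendor_of("amd1") == "amd"
```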
---
## GPUStats
## `GPUStats` {#max.diagnostics.gpu.GPUStats}
> class max.diagnostics.gpu.GPUStats(memory, utilization)
Comprehensive GPU state snapshot containing memory and utilization metrics.
This class provides a complete view of a GPU’s current state, including
detailed memory usage statistics and utilization percentages. It serves
as the primary data structure returned by GPU diagnostic queries.
### `memory` {#max.diagnostics.gpu.GPUStats.memory}
> memory: [MemoryStats](MemoryStats.md#max.diagnostics.gpu.MemoryStats)
Current GPU memory usage statistics.
---
## MemoryStats
## `MemoryStats` {#max.diagnostics.gpu.MemoryStats}
> class max.diagnostics.gpu.MemoryStats(total\_bytes, free\_bytes, used\_bytes, reserved\_bytes)
Detailed GPU memory usage statistics including total, free, used, and reserved memory.
This class provides comprehensive memory information for a GPU, allowing
developers to monitor memory consumption and identify potential memory
bottlenecks during model inference or training.
### `free_bytes` {#max.diagnostics.gpu.MemoryStats.free_bytes}
> free\_bytes: [int](https://docs.python.org/3/library/functions.html#int)
Currently available (free) GPU memory in bytes.
### `total_bytes` {#max.diagnostics.gpu.MemoryStats.total_bytes}
> total\_bytes: [int](https://docs.python.org/3/library/functions.html#int)
Total GPU memory in bytes.
### `used_bytes` {#max.diagnostics.gpu.MemoryStats.used_bytes}
> used\_bytes: [int](https://docs.python.org/3/library/functions.html#int)
Currently allocated (used) GPU memory in bytes.
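These raw byte counts compose naturally into derived metrics. For instance, a utilization percentage can be computed from `used_bytes` and `total_bytes`; the helper below is an illustrative sketch, not part of the API:

```python
def memory_utilization_percent(total_bytes: int, used_bytes: int) -> float:
    """Percentage of GPU memory currently in use."""
    if total_bytes <= 0:
        raise ValueError("total_bytes must be positive")
    return 100.0 * used_bytes / total_bytes

# Example with made-up values: 6 GiB used on a 24 GiB device.
GIB = 1024 ** 3
assert memory_utilization_percent(24 * GIB, 6 * GIB) == 25.0
```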
---
## UtilizationStats
## `UtilizationStats` {#max.diagnostics.gpu.UtilizationStats}
> class max.diagnostics.gpu.UtilizationStats(gpu\_usage\_percent, memory\_activity\_percent)
GPU compute and memory activity utilization percentages.
This class captures the current utilization levels of a GPU’s compute
units and memory subsystem, providing insights into how effectively
the GPU resources are being utilized during workload execution.
### `gpu_usage_percent` {#max.diagnostics.gpu.UtilizationStats.gpu_usage_percent}
> gpu\_usage\_percent: [int](https://docs.python.org/3/library/functions.html#int)
Current GPU compute utilization as a percentage (0-100).
### `memory_activity_percent` {#max.diagnostics.gpu.UtilizationStats.memory_activity_percent}
> memory\_activity\_percent: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None)
Memory controller activity percentage, if available from the GPU vendor.
---
## gpu
Real-time GPU monitoring and diagnostic capabilities for NVIDIA and AMD graphics
hardware.
The GPU diagnostics module provides comprehensive tools for monitoring graphics
processing unit performance, memory usage, and utilization metrics. It supports
both NVIDIA GPUs through NVML and AMD GPUs through ROCm SMI, offering unified
access to hardware statistics regardless of vendor. The API enables both
synchronous queries for immediate metrics and asynchronous background collection
for continuous monitoring during long-running inference sessions.
## Classes
* [`BackgroundRecorder`](/max/api/python/diagnostics/gpu/BackgroundRecorder):
Asynchronous GPU metrics collection.
* [`GPUDiagContext`](/max/api/python/diagnostics/gpu/GPUDiagContext):
Context manager providing unified access to GPU diagnostic information across
NVIDIA and AMD hardware.
* [`GPUStats`](/max/api/python/diagnostics/gpu/GPUStats): Comprehensive
GPU state snapshot containing memory and utilization statistics.
* [`MemoryStats`](/max/api/python/diagnostics/gpu/MemoryStats): Detailed
GPU memory usage statistics including total, free, used, and reserved memory.
* [`UtilizationStats`](/max/api/python/diagnostics/gpu/UtilizationStats):
GPU compute and memory activity utilization percentages.
---
## driver
Exposes APIs for interacting with hardware, such as allocating tensors on a GPU
and moving tensors between the CPU and GPU. It provides interfaces for memory
management, device properties, and hardware monitoring. Through these APIs, you
can control data placement, track resource utilization, and configure device
settings for optimal performance.
For example, the following code uses an accelerator if one is available, and
otherwise falls back to the CPU:
```python
from max import driver
device = driver.CPU() if driver.accelerator_count() == 0 else driver.Accelerator()
print(f"Using {device} device")
```
## `Accelerator` {#max.driver.Accelerator}
> class max.driver.Accelerator(\*args, \*\*kwargs)
## `Buffer` {#max.driver.Buffer}
> class max.driver.Buffer(\*args, \*\*kwargs)
Device-resident buffer representation.
Allocates memory onto a given device with the provided shape and dtype.
Buffers can be sliced to provide strided views of the underlying memory,
but any buffers input into model execution must be contiguous.
Supports numpy-style slicing but does not currently support setting
items across multiple indices.
```python
from max import driver
from max.dtype import DType
cpu_buffer = driver.Buffer(shape=[2, 3], dtype=DType.float32)
# Create a buffer on GPU
gpu = driver.Accelerator()
gpu_buffer = driver.Buffer(shape=[2, 3], dtype=DType.float32, device=gpu)
```
**Parameters:**
* dtype ([DType](dtype.md#max.dtype.DType)) – Data type of buffer elements.
* shape (Sequence\[[int](https://docs.python.org/3/library/functions.html#int)]) – Tuple of positive, non-zero integers denoting the buffer shape.
* device ([Device](#max.driver.Device), optional) – Device to allocate buffer onto. Defaults to the CPU.
* pinned ([bool](https://docs.python.org/3/library/functions.html#bool), optional) – If True, memory is page-locked (pinned). Defaults to False.
* stream ([DeviceStream](#max.driver.DeviceStream), optional) – Stream to associate the buffer with.
### `contiguous()` {#max.driver.Buffer.contiguous}
> contiguous()
Creates a contiguous copy of the parent buffer.
**Parameters:**
self ([Buffer](#max.driver.Buffer))
**Return type:**
[Buffer](#max.driver.Buffer)
### `copy()` {#max.driver.Buffer.copy}
> copy(self, stream: [max.driver.DeviceStream](#max.driver.DeviceStream)) → [max.driver.Buffer](#max.driver.Buffer)
> copy(self, device: [max.driver.Device](#max.driver.Device) | [None](https://docs.python.org/3/library/constants.html#None) = None) → [max.driver.Buffer](#max.driver.Buffer)
Overloaded function.
1. `copy(self, stream: max.driver.DeviceStream) -> max.driver.Buffer`
> Creates a deep copy on the device associated with the stream.
> Args:
> : stream (DeviceStream): The stream to associate the new buffer with.
> Returns:
> : Buffer: A new buffer that is a copy of this buffer.
2. `copy(self, device: max.driver.Device | None = None) -> max.driver.Buffer`
> Creates a deep copy on an optionally given device.
> If device is None (default), a copy is created on the same device.
>
> ```python
> from max import driver
> from max.dtype import DType
>
> cpu_buffer = driver.Buffer(shape=[2, 3], dtype=DType.bfloat16, device=driver.CPU())
> cpu_copy = cpu_buffer.copy()
>
> # Copy to GPU
> gpu = driver.Accelerator()
> gpu_copy = cpu_buffer.copy(device=gpu)
> ```
> Args:
> : device (Device, optional): The device to create the copy on.
> : Defaults to None (same device).
> Returns:
> : Buffer: A new buffer that is a copy of this buffer.
### `device` {#max.driver.Buffer.device}
> property device
Device on which the buffer is resident.
### `disable_auto_sync()` {#max.driver.Buffer.disable_auto_sync}
> disable\_auto\_sync(self) → [None](https://docs.python.org/3/library/constants.html#None)
Disables automatic synchronization for asynchronous operations on this buffer.
:::caution Caution
This is an experimental feature that may be unstable. It also
requires special care from the user to ensure proper synchronization.
:::
By default, certain operations on buffers cause synchronization, such as
accessing a buffer on the host through `to_numpy()`. However, the default
synchronization is quite conservative and often waits on more than is
strictly needed.
This function disables the default synchronization method and enables
mark\_as\_ready(), which allows for a finer control of what is waited on
when a buffer needs to be synchronized.
```python
# Assuming we have 3 buffers of the same size: a, b, and c
# Default case with auto-synchronization
a.to(b) # 1
a.to(c) # 2
# Will wait on 1 and 2
b.to_numpy()
# Disabled synchronization
a.disable_auto_sync()
a.to(b) # 1
a.to(c) # 2
# Doesn't wait on 1 or 2, data in b could be invalid
b.to_numpy()
# Disabled synchronization with mark_as_ready
a.disable_auto_sync()
a.to(b) # 1
b.mark_as_ready()
a.to(c) # 2
# Wait on 1 but not on 2
b.to_numpy()
```
### `dtype` {#max.driver.Buffer.dtype}
> property dtype
DType of the buffer's constituent elements.
### `element_size` {#max.driver.Buffer.element_size}
> property element\_size
Return the size of the element type in bytes.
### `from_dlpack()` {#max.driver.Buffer.from_dlpack}
> from\_dlpack(\*, copy=None)
Create a buffer from an object implementing the dlpack protocol.
This usually does not result in a copy, and the producer of the object
retains ownership of the underlying memory.
### `from_numpy()` {#max.driver.Buffer.from_numpy}
> from\_numpy()
Creates a buffer from a provided numpy array on the host device.
The underlying data is not copied unless the array is non-contiguous, in
which case a contiguous copy is returned.
### `inplace_copy_from()` {#max.driver.Buffer.inplace_copy_from}
> inplace\_copy\_from(src)
Copy the contents of another buffer into this one.
These buffers may be on different devices.
Requires that both buffers are contiguous and have the same size.
### `is_contiguous` {#max.driver.Buffer.is_contiguous}
> property is\_contiguous
Whether or not buffer is contiguously allocated in memory. Returns
false if the buffer is a non-contiguous slice.
Currently, certain layouts that are technically contiguous are treated as
non-contiguous for the purposes of the engine, such as buffers with
negative steps.
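NumPy applies a comparable rule: a negative-step slice is a valid view but is not C-contiguous. The sketch below (NumPy, not MAX code) illustrates the distinction this property is drawing:

```python
import numpy as np

a = np.arange(6, dtype=np.float32).reshape(2, 3)
rev = a[:, ::-1]  # negative-step view of the same memory

assert a.flags["C_CONTIGUOUS"]        # original layout is contiguous
assert not rev.flags["C_CONTIGUOUS"]  # reversed view is not
```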
### `is_host` {#max.driver.Buffer.is_host}
> property is\_host
Whether or not buffer is host-resident. Returns false for GPU buffers,
true for CPU buffers.
```python
from max import driver
from max.dtype import DType
cpu_buffer = driver.Buffer(shape=[2, 3], dtype=DType.bfloat16, device=driver.CPU())
print(cpu_buffer.is_host)
```
### `item()` {#max.driver.Buffer.item}
> item(self) → [Any](https://docs.python.org/3/library/typing.html#typing.Any)
Returns the scalar value at a given location. Currently
implemented only for zero-rank buffers. The return type is
converted to a Python built-in type.
### `mark_as_ready()` {#max.driver.Buffer.mark_as_ready}
> mark\_as\_ready(self) → [None](https://docs.python.org/3/library/constants.html#None)
Establishes a synchronization point for buffers with disabled auto-sync.
:::caution Caution
This is an experimental feature that may be unstable. It also
requires special care from the user to ensure proper synchronization.
:::
This method can only be called on buffers with disabled synchronization
through disable\_auto\_sync().
It instructs MAX that, whenever it needs to wait on this buffer, it
should only wait up to the point where this method was called.
It can be called multiple times, but it will override a previous
synchronization point with the new one.
Refer to the disable\_auto\_sync() documentation for more details and examples.
### `mmap()` {#max.driver.Buffer.mmap}
> mmap(dtype, shape, mode='copyonwrite', offset=0)
### `num_elements` {#max.driver.Buffer.num_elements}
> property num\_elements
Returns the number of elements in this buffer.
Rank-0 buffers have 1 element by convention.
### `pinned` {#max.driver.Buffer.pinned}
> property pinned
Whether or not the underlying memory is pinned (page-locked).
### `rank` {#max.driver.Buffer.rank}
> property rank
Buffer rank.
### `scalar` {#max.driver.Buffer.scalar}
> scalar
### `shape` {#max.driver.Buffer.shape}
> property shape
Shape of buffer.
### `stream` {#max.driver.Buffer.stream}
> property stream
Stream to which the buffer is bound.
### `to()` {#max.driver.Buffer.to}
> to(self, device: [max.driver.Device](#max.driver.Device)) → [max.driver.Buffer](#max.driver.Buffer)
> to(self, stream: [max.driver.DeviceStream](#max.driver.DeviceStream)) → [max.driver.Buffer](#max.driver.Buffer)
> to(self, devices: [collections.abc.Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[max.driver.Device](#max.driver.Device)]) → [list](https://docs.python.org/3/library/stdtypes.html#list)\[[max.driver.Buffer](#max.driver.Buffer)]
> to(self, streams: [collections.abc.Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[max.driver.DeviceStream](#max.driver.DeviceStream)]) → [list](https://docs.python.org/3/library/stdtypes.html#list)\[[max.driver.Buffer](#max.driver.Buffer)]
Overloaded function.
1. `to(self, device: max.driver.Device) -> max.driver.Buffer`
> Return a buffer that’s guaranteed to be on the given device.
> The buffer is only copied if the requested device is different from the
> device upon which the buffer is already resident.
2. `to(self, stream: max.driver.DeviceStream) -> max.driver.Buffer`
> Return a buffer that’s guaranteed to be on the given device and associated
> with the given stream.
> The buffer is only copied if the requested device is different from the
> device upon which the buffer is already resident. If the destination
> stream is on the same device, then a new reference to the same buffer is
> returned.
3. `to(self, devices: collections.abc.Sequence[max.driver.Device]) -> list[max.driver.Buffer]`
> Return a list of buffers that are guaranteed to be on the given devices.
> The buffers are only copied if the requested devices are different from the
> device upon which the buffer is already resident.
4. `to(self, streams: collections.abc.Sequence[max.driver.DeviceStream]) -> list[max.driver.Buffer]`
> Return a list of buffers that are guaranteed to be on the given streams.
> The buffers are only copied if the requested streams are different from the
> stream upon which the buffer is already resident.
### `to_numpy()` {#max.driver.Buffer.to_numpy}
> to\_numpy()
Converts the buffer to a numpy array.
If the buffer is not on the host, a copy will be issued.
### `view()` {#max.driver.Buffer.view}
> view(dtype, shape=None)
Return a new buffer with the given type and shape that shares the underlying memory.
If the shape is not given, it will be deduced if possible, or a
ValueError is raised.
### `zeros` {#max.driver.Buffer.zeros}
> zeros
## `CPU` {#max.driver.CPU}
> class max.driver.CPU(\*args, \*\*kwargs)
## `DLPackArray` {#max.driver.DLPackArray}
> class max.driver.DLPackArray(\*args, \*\*kwargs)
## `Device` {#max.driver.Device}
> class max.driver.Device
### `api` {#max.driver.Device.api}
> property api
Returns the API used to program the device.
Possible values are:
* `cpu` for host devices.
* `cuda` for NVIDIA GPUs.
* `hip` for AMD GPUs.
```python
from max import driver
device = driver.CPU()
device.api
```
### `architecture_name` {#max.driver.Device.architecture_name}
> property architecture\_name
Returns the architecture name of the device.
Examples of possible values:
* `gfx90a`, `gfx942` for AMD GPUs.
* `sm_80`, `sm_86` for NVIDIA GPUs.
* CPU devices raise an exception.
```python
from max import driver
device = driver.Accelerator()
device.architecture_name
```
### `can_access()` {#max.driver.Device.can_access}
> can\_access(self, other: [max.driver.Device](#max.driver.Device)) → [bool](https://docs.python.org/3/library/functions.html#bool)
Checks if this device can directly access memory of another device.
```python
from max import driver
gpu0 = driver.Accelerator(id=0)
gpu1 = driver.Accelerator(id=1)
if gpu0.can_access(gpu1):
    print("GPU0 can directly access GPU1 memory.")
```
**Parameters:**
other ([Device](#max.driver.Device)) – The other device to check peer access against.
### `cpu` {#max.driver.Device.cpu}
> cpu
### `default_stream` {#max.driver.Device.default_stream}
> property default\_stream
Returns the default stream for this device.
The default stream is initialized when the device object is created.
**Returns:**
The default execution stream for this device.
**Return type:**
[DeviceStream](#max.driver.DeviceStream)
### `id` {#max.driver.Device.id}
> property id
Returns a zero-based device id. For a CPU device this is always 0.
For GPU accelerators this is the id of the device relative to this host.
Along with the `label`, an id can uniquely identify a device,
e.g. `gpu:0`, `gpu:1`.
```python
from max import driver
device = driver.Accelerator()
device_id = device.id
```
### `is_host` {#max.driver.Device.is_host}
> property is\_host
Whether this device is the CPU (host) device.
```python
from max import driver
device = driver.CPU()
device.is_host
```
### `label` {#max.driver.Device.label}
> property label
Returns device label.
Possible values are:
* `cpu` for host devices.
* `gpu` for accelerators.
```python
from max import driver
device = driver.CPU()
device.label
```
### `stats` {#max.driver.Device.stats}
> property stats
Returns utilization data for the device.
```python
from max import driver
device = driver.CPU()
stats = device.stats
```
**Returns:**
A dictionary containing device utilization statistics.
### `synchronize()` {#max.driver.Device.synchronize}
> synchronize(self) → [None](https://docs.python.org/3/library/constants.html#None)
Ensures all operations on this device complete before returning.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If any enqueued operations had an internal error.
## `DeviceSpec` {#max.driver.DeviceSpec}
> class max.driver.DeviceSpec(id, device\_type='cpu')
Specification for a device, containing its ID and type.
This class provides a way to specify device parameters like ID and type (CPU/GPU)
for creating Device instances.
**Parameters:**
* id ([int](https://docs.python.org/3/library/functions.html#int))
* device\_type ([Literal](https://docs.python.org/3/library/typing.html#typing.Literal)\['cpu', 'gpu'])
### `cpu()` {#max.driver.DeviceSpec.cpu}
> static cpu(id=-1)
Creates a CPU device specification.
**Parameters:**
id ([int](https://docs.python.org/3/library/functions.html#int))
### `device_type` {#max.driver.DeviceSpec.device_type}
> device\_type: [Literal](https://docs.python.org/3/library/typing.html#typing.Literal)\['cpu', 'gpu'] = 'cpu'
Type of specified device.
### `id` {#max.driver.DeviceSpec.id}
> id: [int](https://docs.python.org/3/library/functions.html#int)
Provided id for this device.
## `DeviceStream` {#max.driver.DeviceStream}
> class max.driver.DeviceStream(\*args, \*\*kwargs)
Provides access to a stream of execution on a device.
A stream represents a sequence of operations that will be executed in order.
Multiple streams on the same device can execute concurrently.
```python
from max import driver
# Create a default accelerator device
device = driver.Accelerator()
# Get the default stream for the device
stream = device.default_stream
# Create a new stream of execution on the device
new_stream = driver.DeviceStream(device)
```
### `device` {#max.driver.DeviceStream.device}
> property device
The device this stream is executing on.
### `synchronize()` {#max.driver.DeviceStream.synchronize}
> synchronize(self) → [None](https://docs.python.org/3/library/constants.html#None)
Ensures all operations on this stream complete before returning.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If any enqueued operations had an internal error.
### `wait_for()` {#max.driver.DeviceStream.wait_for}
> wait\_for(self, stream: [max.driver.DeviceStream](#max.driver.DeviceStream)) → [None](https://docs.python.org/3/library/constants.html#None)
> wait\_for(self, device: [max.driver.Device](#max.driver.Device)) → [None](https://docs.python.org/3/library/constants.html#None)
Overloaded function.
1. `wait_for(self, stream: max.driver.DeviceStream) -> None`
> Ensures all operations on the other stream complete before future work
> submitted to this stream is scheduled.
> Args:
> : stream (DeviceStream): The stream to wait for.
2. `wait_for(self, device: max.driver.Device) -> None`
> Ensures all operations on device’s default stream complete before
> future work submitted to this stream is scheduled.
> Args:
> : device (Device): The device whose default stream to wait for.
## `accelerator_api()` {#max.driver.accelerator_api}
> max.driver.accelerator\_api()
Returns the API used to program the accelerator.
## `accelerator_architecture_name()` {#max.driver.accelerator_architecture_name}
> max.driver.accelerator\_architecture\_name()
Returns the architecture name of the accelerator device.
## `calculate_virtual_device_count()` {#max.driver.calculate_virtual_device_count}
> max.driver.calculate\_virtual\_device\_count(\*device\_spec\_lists)
Calculate the minimum virtual device count needed for the given device specs.
**Parameters:**
\*device\_spec\_lists ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[DeviceSpec](#max.driver.DeviceSpec)]) – One or more lists of DeviceSpec objects (e.g., main devices
and draft devices)
**Returns:**
The minimum number of virtual devices needed (max GPU ID + 1), or 1 if no GPUs
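The documented rule (max GPU id + 1, or 1 when no GPUs are present) is easy to state in code. The sketch below re-implements it for illustration only, with `(id, device_type)` tuples standing in for `DeviceSpec` objects:

```python
def virtual_device_count(*device_spec_lists):
    """Illustrative rule: max GPU id + 1, or 1 when no GPU specs exist."""
    gpu_ids = [
        spec_id
        for specs in device_spec_lists
        for spec_id, device_type in specs
        if device_type == "gpu"
    ]
    return max(gpu_ids) + 1 if gpu_ids else 1

# Main devices gpu:0 and gpu:2 plus a draft device gpu:1 need 3 virtual devices.
assert virtual_device_count([(0, "gpu"), (2, "gpu")], [(1, "gpu")]) == 3
assert virtual_device_count([(0, "cpu")]) == 1
```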
## `calculate_virtual_device_count_from_cli()` {#max.driver.calculate_virtual_device_count_from_cli}
> max.driver.calculate\_virtual\_device\_count\_from\_cli(\*device\_inputs)
Calculate virtual device count from raw CLI inputs (before parsing).
This helper works with the raw device input strings or lists before they’re
parsed into DeviceSpec objects. Used when virtual device mode needs to be
enabled before device validation occurs.
**Parameters:**
\*device\_inputs ([str](https://docs.python.org/3/library/stdtypes.html#str) | [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]) – One or more raw device inputs: either strings like `gpu:0,1,2`
or lists of integers like `[0, 1, 2]`
**Returns:**
The minimum number of virtual devices needed (max GPU ID + 1), or 1 if no GPUs
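The raw input formats mentioned above can be normalized with a small parser. This sketch only reflects the two documented shapes (`gpu:0,1,2` strings and integer lists) and is not the actual implementation:

```python
def parse_device_ids(device_input):
    """Extract GPU ids from 'gpu:0,1,2'-style strings or lists of ints."""
    if isinstance(device_input, str):
        kind, _, ids = device_input.partition(":")
        if kind != "gpu" or not ids:
            return []  # non-GPU inputs contribute no GPU ids
        return [int(i) for i in ids.split(",")]
    return list(device_input)

assert parse_device_ids("gpu:0,1,2") == [0, 1, 2]
assert parse_device_ids([0, 1]) == [0, 1]
assert parse_device_ids("cpu") == []
```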
## `load_devices()` {#max.driver.load_devices}
> max.driver.load\_devices(device\_specs)
Initialize and return a list of devices, given a list of device specs.
## `load_max_buffer()` {#max.driver.load_max_buffer}
> max.driver.load\_max\_buffer(path)
Experimental method for loading serialized MAX buffers.
MAX buffers can be exported by creating a graph and calling Value.print()
with the BINARY\_MAX\_CHECKPOINT option.
**Parameters:**
path ([PathLike](https://docs.python.org/3/library/os.html#os.PathLike)\[[str](https://docs.python.org/3/library/stdtypes.html#str)]) – Path to buffer (should end with .max)
**Returns:**
A Buffer created from the path. The shape and dtype are read
from the file.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If the file format is not the MAX checkpoint format.
**Return type:**
[Buffer](#max.driver.Buffer)
## `scan_available_devices()` {#max.driver.scan_available_devices}
> max.driver.scan\_available\_devices()
Returns all available accelerators, or the CPU if none are found.
## `accelerator_count()` {#max.driver.accelerator_count}
> max.driver.accelerator\_count() → [int](https://docs.python.org/3/library/functions.html#int)
Returns number of accelerator devices available.
---
## dtype
Provides data type definitions for tensors in MAX Engine. These data types are
essential for defining the precision and memory layout of tensor data when
working with machine learning models.
This module defines the [`DType`](#max.dtype.DType) enum, which represents all supported tensor
data types in MAX Engine, including:
* Integer types (signed and unsigned): `int8` | `uint8` | `int16` | `uint16` | `int32` | `uint32` | `int64` | `uint64`
* Floating-point types: `float16` | `bfloat16` | `float32` | `float64`, plus reduced-precision `float8` and `float4` variants
* Boolean type: `bool`
The module also provides utilities for converting between MAX Engine data types
and [NumPy dtypes](https://numpy.org/doc/stable/user/basics.types.html), making
it easy to interoperate with the NumPy ecosystem.
```python
import numpy as np
from max.dtype import DType
# Create a NumPy array using a MAX dtype
tensor = np.zeros((2, 3), dtype=DType.float32.to_numpy())
# Convert NumPy dtype to MAX DType
array = np.ones((4, 4), dtype=np.float16)
max_dtype = DType.from_numpy(array.dtype)
# Check properties of data types
is_float = DType.float32.is_float() # True
is_int = DType.int64.is_integral() # True
size = DType.float64.size_in_bytes # 8
```
## `DType` {#max.dtype.DType}
> class max.dtype.DType(value, names=\<not given>, \*values, module=None, qualname=None, type=None, start=1, boundary=None)
The tensor data type.
### `align` {#max.dtype.DType.align}
> property align
Returns the alignment requirement of the data type in bytes.
The alignment specifies the memory boundary that values of this data type
must be aligned to for optimal performance and correctness.
### `bfloat16` {#max.dtype.DType.bfloat16}
> bfloat16 = 80
### `bool` {#max.dtype.DType.bool}
> bool = 1
### `float16` {#max.dtype.DType.float16}
> float16 = 79
### `float32` {#max.dtype.DType.float32}
> float32 = 81
### `float4_e2m1fn` {#max.dtype.DType.float4_e2m1fn}
> float4\_e2m1fn = 64
### `float64` {#max.dtype.DType.float64}
> float64 = 82
### `float8_e4m3fn` {#max.dtype.DType.float8_e4m3fn}
> float8\_e4m3fn = 75
### `float8_e4m3fnuz` {#max.dtype.DType.float8_e4m3fnuz}
> float8\_e4m3fnuz = 76
### `float8_e5m2` {#max.dtype.DType.float8_e5m2}
> float8\_e5m2 = 77
### `float8_e5m2fnuz` {#max.dtype.DType.float8_e5m2fnuz}
> float8\_e5m2fnuz = 78
### `float8_e8m0fnu` {#max.dtype.DType.float8_e8m0fnu}
> float8\_e8m0fnu = 73
### `from_numpy()` {#max.dtype.DType.from_numpy}
> from\_numpy(dtype)
Converts a NumPy dtype to the corresponding DType.
**Parameters:**
dtype (np.dtype) – The NumPy dtype to convert.
**Returns:**
The corresponding DType enum value.
**Return type:**
[DType](#max.dtype.DType)
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If the input dtype is not supported.
### `int16` {#max.dtype.DType.int16}
> int16 = 137
### `int32` {#max.dtype.DType.int32}
> int32 = 139
### `int64` {#max.dtype.DType.int64}
> int64 = 141
### `int8` {#max.dtype.DType.int8}
> int8 = 135
### `is_float()` {#max.dtype.DType.is_float}
> is\_float(self) → [bool](https://docs.python.org/3/library/functions.html#bool)
Checks if the data type is a floating-point type.
### `is_float8()` {#max.dtype.DType.is_float8}
> is\_float8(self) → [bool](https://docs.python.org/3/library/functions.html#bool)
Checks if the data type is an 8-bit floating-point type.
### `is_half()` {#max.dtype.DType.is_half}
> is\_half(self) → [bool](https://docs.python.org/3/library/functions.html#bool)
Checks if the data type is a half-precision floating-point type.
### `is_integral()` {#max.dtype.DType.is_integral}
> is\_integral(self) → [bool](https://docs.python.org/3/library/functions.html#bool)
Checks if the data type is an integer type.
### `is_signed_integral()` {#max.dtype.DType.is_signed_integral}
> is\_signed\_integral(self) → [bool](https://docs.python.org/3/library/functions.html#bool)
Checks if the data type is a signed integer type.
### `is_unsigned_integral()` {#max.dtype.DType.is_unsigned_integral}
> is\_unsigned\_integral(self) → [bool](https://docs.python.org/3/library/functions.html#bool)
Checks if the data type is an unsigned integer type.
### `size_in_bits` {#max.dtype.DType.size_in_bits}
> property size\_in\_bits
Returns the size of the data type in bits.
This indicates how many bits are required to store a single value
of this data type in memory.
### `size_in_bytes` {#max.dtype.DType.size_in_bytes}
> property size\_in\_bytes
Returns the size of the data type in bytes.
This indicates how many bytes are required to store a single value
of this data type in memory.
### `to_numpy()` {#max.dtype.DType.to_numpy}
> to\_numpy()
Converts this `DType` to the corresponding NumPy dtype.
**Returns:**
The corresponding NumPy dtype object.
**Return type:**
np.dtype
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If the dtype is not supported.
### `uint16` {#max.dtype.DType.uint16}
> uint16 = 136
### `uint32` {#max.dtype.DType.uint32}
> uint32 = 138
### `uint64` {#max.dtype.DType.uint64}
> uint64 = 140
### `uint8` {#max.dtype.DType.uint8}
> uint8 = 134
## `finfo` {#max.dtype.finfo}
> class max.dtype.finfo(dtype)
Numerical properties of a floating point `max.dtype.DType`.
This is modeled after `torch.finfo`, providing `bits`, `eps`,
`max`, `min`, `tiny`, `smallest_normal`, and `dtype`
attributes for every MAX float dtype—including bfloat16, float8, and
float4 types that NumPy cannot represent.
**Parameters:**
dtype ([DType](#max.dtype.DType)) – A floating-point `DType` to query.
**Raises:**
[TypeError](https://docs.python.org/3/library/exceptions.html#TypeError) – If dtype is not a floating-point type.
### `bits` {#max.dtype.finfo.bits}
> bits: [int](https://docs.python.org/3/library/functions.html#int)
### `dtype` {#max.dtype.finfo.dtype}
> dtype: [DType](#max.dtype.DType)
### `eps` {#max.dtype.finfo.eps}
> eps: [float](https://docs.python.org/3/library/functions.html#float)
### `max` {#max.dtype.finfo.max}
> max: [float](https://docs.python.org/3/library/functions.html#float)
### `min` {#max.dtype.finfo.min}
> min: [float](https://docs.python.org/3/library/functions.html#float)
### `smallest_normal` {#max.dtype.finfo.smallest_normal}
> property smallest\_normal: [float](https://docs.python.org/3/library/functions.html#float)
Alias for `tiny` (`torch.finfo` compatibility).
### `tiny` {#max.dtype.finfo.tiny}
> tiny: [float](https://docs.python.org/3/library/functions.html#float)
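For intuition about what these attributes report, the `float16` values can be derived by hand from the IEEE 754 binary16 format (5 exponent bits, 10 explicit mantissa bits). This stdlib-only sketch is an illustration, not part of the API:

```python
# IEEE 754 binary16: 1 sign bit, 5 exponent bits, 10 mantissa bits
MANTISSA_BITS = 10
EXP_MAX = 15   # largest normal binary exponent
EXP_MIN = -14  # smallest normal binary exponent

bits = 16
eps = 2.0 ** -MANTISSA_BITS   # spacing between 1.0 and the next value
tiny = 2.0 ** EXP_MIN         # smallest positive normal number
max_val = (2 - 2.0 ** -MANTISSA_BITS) * 2.0 ** EXP_MAX

print(bits, eps, tiny, max_val)  # 16 0.0009765625 6.103515625e-05 65504.0
```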
---
## engine
The APIs in this module allow you to run inference with MAX Engine—a graph
compiler and runtime that accelerates your AI models on a wide variety of
hardware.
## `InferenceSession` {#max.engine.InferenceSession}
> class max.engine.InferenceSession(devices, num\_threads=None, \*, custom\_extensions=None)
Manages an inference session in which you can load and run models.
You need an instance of this to load a model as a [`Model`](#max.engine.Model) object.
For example:
```python
from pathlib import Path
from max import engine
from max.driver import CPU

session = engine.InferenceSession(devices=[CPU()])
model_path = Path('bert-base-uncased')
model = session.load(model_path)
```
### `devices` {#max.engine.InferenceSession.devices}
> property devices: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[Device](driver.md#max.driver.Device)]
A list of available devices.
### `gpu_profiling()` {#max.engine.InferenceSession.gpu_profiling}
> gpu\_profiling(mode)
Enables GPU profiling instrumentation for the session.
This enables GPU profiling instrumentation that works with NVIDIA
Nsight Systems and Nsight Compute. When enabled, the runtime adds CUDA
driver calls and NVTX markers that allow profiling tools to correlate
GPU kernel executions with host-side code.
For example, to enable detailed profiling for Nsight Systems analysis,
call `gpu_profiling()` before `load()`:
```python
from max.engine import InferenceSession, GPUProfilingMode
from max.driver import Accelerator
session = InferenceSession(devices=[Accelerator()])
session.gpu_profiling(GPUProfilingMode.DETAILED)
model = session.load(my_graph)
```
Then run it with `nsys`:
```bash
nsys profile --trace=cuda,nvtx python example.py
```
Or, instead of calling `session.gpu_profiling()` in the code, you can
set the `MODULAR_ENABLE_PROFILING` environment variable when you call
`nsys profile`:
```bash
MODULAR_ENABLE_PROFILING=detailed nsys profile --trace=cuda,nvtx python script.py
```
Beware that `gpu_profiling()` overrides the
`MODULAR_ENABLE_PROFILING` environment variable if also used.
:::note Note
Profiling instrumentation adds runtime overhead and should be
disabled for production deployments.
:::
**Parameters:**
mode ([GPUProfilingMode](#max.engine.GPUProfilingMode)) –
The profiling mode to set. One of:
* [`GPUProfilingMode.OFF`](#max.engine.GPUProfilingMode.OFF): Disable profiling (default).
* [`GPUProfilingMode.ON`](#max.engine.GPUProfilingMode.ON): Enable basic profiling with
NVTX markers for kernel correlation.
* [`GPUProfilingMode.DETAILED`](#max.engine.GPUProfilingMode.DETAILED): Enable detailed profiling
with additional Python-level NVTX markers.
**Return type:**
None
:::note See also
* [GPU profiling with Nsight Systems](/max/gpu-system-profiling)
:::
### `load()` {#max.engine.InferenceSession.load}
> load(model, \*, custom\_extensions=None, weights\_registry=None)
Loads a trained model and compiles it for inference.
**Parameters:**
* model ([str](https://docs.python.org/3/library/stdtypes.html#str) | Path | [Graph](graph/Graph.md#max.graph.Graph)) – Path to a model.
* custom\_extensions (CustomExtensionsType | None) – The extensions to load for the model.
Supports paths to .mojopkg custom ops.
* weights\_registry (Mapping\[[str](https://docs.python.org/3/library/stdtypes.html#str), [DLPackArray](driver.md#max.driver.DLPackArray)] | None) – A mapping from model weight names to
their values. The values are currently expected to be DLPack
arrays. If an array is a read-only NumPy array, the user must
ensure that its lifetime extends beyond the lifetime of the model.
**Returns:**
The loaded model, compiled and ready to execute.
**Raises:**
[RuntimeError](https://docs.python.org/3/library/exceptions.html#RuntimeError) – If the path provided is invalid.
**Return type:**
[Model](#max.engine.Model)
### `set_mojo_assert_level()` {#max.engine.InferenceSession.set_mojo_assert_level}
> set\_mojo\_assert\_level(level)
Sets which Mojo asserts are kept in the compiled model.
**Parameters:**
level (AssertLevel)
**Return type:**
None
### `set_mojo_log_level()` {#max.engine.InferenceSession.set_mojo_log_level}
> set\_mojo\_log\_level(level)
Sets the verbosity of Mojo logging in the compiled model.
### `set_split_k_reduction_precision()` {#max.engine.InferenceSession.set_split_k_reduction_precision}
> set\_split\_k\_reduction\_precision(precision)
Sets the accumulation precision for split-k reductions in large matmuls.
### `use_old_top_k_kernel()` {#max.engine.InferenceSession.use_old_top_k_kernel}
> use\_old\_top\_k\_kernel(mode)
Enables the old top-k kernel.
By default, the new top-k kernel is used, to keep behavior consistent with
`max/kernels/src/nn/topk.mojo`.
**Parameters:**
mode ([str](https://docs.python.org/3/library/stdtypes.html#str)) – String to enable or disable the old kernel. Accepts “false”, “off”,
“no”, or “0” to disable; any other value enables.
**Return type:**
None
## `Model` {#max.engine.Model}
> class max.engine.Model
A loaded model that you can execute.
Do not instantiate this class directly. Instead, create it with
[`InferenceSession`](#max.engine.InferenceSession).
### `__call__()` {#max.engine.Model.__call}
> \_\_call\_\_(\*args, \*\*kwargs)
Call self as a function.
### `capture()` {#max.engine.Model.capture}
> capture(\*inputs)
Capture execution into a device graph keyed by input shapes/dtypes.
Capture is best-effort and model-dependent. If the model issues
capture-unsafe operations (for example, host-device synchronization),
graph capture may fail. Callers should choose capture-safe execution paths.
### `input_metadata` {#max.engine.Model.input_metadata}
> property input\_metadata
Metadata about the model’s input tensors, as a list of
[`TensorSpec`](#max.engine.TensorSpec) objects.
For example, you can print the input tensor names, shapes, and dtypes:
```python
for tensor in model.input_metadata:
print(f'name: {tensor.name}, shape: {tensor.shape}, dtype: {tensor.dtype}')
```
### `output_metadata` {#max.engine.Model.output_metadata}
> property output\_metadata
Metadata about the model’s output tensors, as a list of
[`TensorSpec`](#max.engine.TensorSpec) objects.
For example, you can print the output tensor names, shapes, and dtypes:
```python
for tensor in model.output_metadata:
print(f'name: {tensor.name}, shape: {tensor.shape}, dtype: {tensor.dtype}')
```
### `replay()` {#max.engine.Model.replay}
> replay(\*inputs)
Replay the captured device graph for these inputs.
## `GPUProfilingMode` {#max.engine.GPUProfilingMode}
> class max.engine.GPUProfilingMode(value, names=\<not given>, \*values, module=None, qualname=None, type=None, start=1, boundary=None)
The supported modes for GPU profiling.
GPU profiling modes control the level of instrumentation when profiling
MAX applications with NVIDIA Nsight Systems or Nsight Compute. Higher
levels provide more detail but may introduce additional overhead.
:::note See also
[`InferenceSession.gpu_profiling()`](#max.engine.InferenceSession.gpu_profiling): Method to set the profiling mode.
:::
### `DETAILED` {#max.engine.GPUProfilingMode.DETAILED}
> DETAILED = 'detailed'
Enable detailed GPU profiling with additional NVTX markers
from Python code. This mode provides the most visibility into
which Python operations correspond to which GPU kernels, but
has the highest overhead.
### `OFF` {#max.engine.GPUProfilingMode.OFF}
> OFF = 'off'
Disable GPU profiling instrumentation. This is the default mode
and incurs no profiling overhead.
### `ON` {#max.engine.GPUProfilingMode.ON}
> ON = 'on'
Enable basic GPU profiling. Adds CUDA driver calls and NVTX
markers for correlating kernel executions with host-side code.
## `LogLevel` {#max.engine.LogLevel}
> class max.engine.LogLevel(value, names=\<not given>, \*values, module=None, qualname=None, type=None, start=1, boundary=None)
The LogLevel specifies the log level used by the Mojo Ops.
### `CRITICAL` {#max.engine.LogLevel.CRITICAL}
> CRITICAL = 'critical'
### `DEBUG` {#max.engine.LogLevel.DEBUG}
> DEBUG = 'debug'
### `ERROR` {#max.engine.LogLevel.ERROR}
> ERROR = 'error'
### `INFO` {#max.engine.LogLevel.INFO}
> INFO = 'info'
### `NOTSET` {#max.engine.LogLevel.NOTSET}
> NOTSET = 'notset'
### `TRACE` {#max.engine.LogLevel.TRACE}
> TRACE = 'trace'
### `WARNING` {#max.engine.LogLevel.WARNING}
> WARNING = 'warning'
## `TensorSpec` {#max.engine.TensorSpec}
> class max.engine.TensorSpec
Defines the properties of a tensor, including its name, shape and
data type.
For usage examples, see [`Model.input_metadata`](#max.engine.Model.input_metadata).
### `dtype` {#max.engine.TensorSpec.dtype}
> property dtype
A tensor data type.
### `name` {#max.engine.TensorSpec.name}
> property name
A tensor name.
### `shape` {#max.engine.TensorSpec.shape}
> property shape
The shape of the tensor as a list of integers.
If a dimension size is unknown/dynamic (such as the batch size), its
value is `None`.
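Since `None` marks a dynamic dimension, validating a concrete input shape against a spec’s shape is a simple wildcard match. A hypothetical helper (`shape_matches` is not part of the API) sketches the idea:

```python
def shape_matches(spec_shape, concrete_shape):
    """True if concrete_shape is compatible with spec_shape,
    where None in spec_shape marks a dynamic dimension."""
    if len(spec_shape) != len(concrete_shape):
        return False
    return all(dim is None or dim == actual
               for dim, actual in zip(spec_shape, concrete_shape))

# A spec shape of [None, 128] accepts any batch size:
print(shape_matches([None, 128], [32, 128]))  # True
print(shape_matches([None, 128], [32, 256]))  # False
```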
## `CustomExtensionsType` {#max.engine.CustomExtensionsType}
> max.engine.CustomExtensionsType = collections.abc.Sequence\[str | pathlib.\_local.Path] | str | pathlib.\_local.Path
Represents a PEP 604 union type (e.g. `int | str`).
---
## entrypoints
## `LLM` {#max.entrypoints.llm.LLM}
> class max.entrypoints.llm.LLM(pipeline\_config)
A high-level interface for interacting with LLMs.
### `generate()` {#max.entrypoints.llm.LLM.generate}
> generate(prompts, max\_new\_tokens=100, use\_tqdm=True)
Generates text completions for the given prompts.
This method is thread-safe and may be used on the same LLM instance
from multiple threads concurrently with no external synchronization.
**Parameters:**
* prompts ([str](https://docs.python.org/3/library/stdtypes.html#str) | [Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[str](https://docs.python.org/3/library/stdtypes.html#str)]) – The input string or list of strings to generate completions for.
* max\_new\_tokens ([int](https://docs.python.org/3/library/functions.html#int) | None) – The maximum number of tokens to generate in the response.
* use\_tqdm ([bool](https://docs.python.org/3/library/functions.html#bool)) – Whether to display a progress bar during generation.
**Returns:**
A list of generated text completions corresponding to each input prompt.
**Raises:**
* [ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If prompts is empty or contains invalid data.
* [RuntimeError](https://docs.python.org/3/library/exceptions.html#RuntimeError) – If the model fails to generate completions.
---
## functional
Provides functional APIs for tensor operations.
This module provides functional-style tensor operations that work seamlessly
with both MAX Graph construction and eager Tensor execution. All operations
are wrapped versions of the core graph operations that automatically handle
different execution contexts.
These operations can be used in both graph construction and eager execution.
## `CustomExtensionType` {#max.functional.CustomExtensionType}
> max.functional.CustomExtensionType: [TypeAlias](https://docs.python.org/3/library/typing.html#typing.TypeAlias) = str | pathlib.\_local.Path
Type alias for custom extension paths, matching `engine.CustomExtensionsType`.
## `abs()` {#max.functional.abs}
> max.functional.abs(x)
Computes the absolute value element-wise.
See [`max.graph.ops.abs()`](graph/ops.md#max.graph.ops.abs) for details.
## `add()` {#max.functional.add}
> max.functional.add(lhs, rhs)
Adds two tensors element-wise.
See [`max.graph.ops.add()`](graph/ops.md#max.graph.ops.add) for details.
## `allreduce_sum()` {#max.functional.allreduce_sum}
> max.functional.allreduce\_sum(inputs, signal\_buffers)
Sums values from multiple devices.
See `max.graph.ops.allreduce.sum()` for details.
## `argmax()` {#max.functional.argmax}
> max.functional.argmax(x, axis=-1)
Returns the indices of the maximum values along an axis.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) – The input tensor.
* axis ([int](https://docs.python.org/3/library/functions.html#int) | None) – The axis along which to find the maximum indices. If None,
finds the index of the maximum across all elements (flattened).
**Returns:**
A tensor containing the indices of the maximum values.
## `argmin()` {#max.functional.argmin}
> max.functional.argmin(x, axis=-1)
Returns the indices of the minimum values along an axis.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) – The input tensor.
* axis ([int](https://docs.python.org/3/library/functions.html#int) | None) – The axis along which to find the minimum indices. If None,
finds the index of the minimum across all elements (flattened).
**Returns:**
A tensor containing the indices of the minimum values.
## `argsort()` {#max.functional.argsort}
> max.functional.argsort(x, ascending=True)
Returns the indices that would sort a tensor along an axis.
See [`max.graph.ops.argsort()`](graph/ops.md#max.graph.ops.argsort) for details.
## `as_interleaved_complex()` {#max.functional.as_interleaved_complex}
> max.functional.as\_interleaved\_complex(x)
Converts a tensor to interleaved complex representation.
See [`max.graph.ops.as_interleaved_complex()`](graph/ops.md#max.graph.ops.as_interleaved_complex) for details.
## `atanh()` {#max.functional.atanh}
> max.functional.atanh(x)
Computes the inverse hyperbolic tangent element-wise.
See [`max.graph.ops.atanh()`](graph/ops.md#max.graph.ops.atanh) for details.
## `band_part()` {#max.functional.band_part}
> max.functional.band\_part(x, num\_lower=None, num\_upper=None, exclude=False)
Copies a tensor setting everything outside a central band to zero.
See [`max.graph.ops.band_part()`](graph/ops.md#max.graph.ops.band_part) for details.
## `broadcast_to()` {#max.functional.broadcast_to}
> max.functional.broadcast\_to(x, shape, out\_dims=None)
Broadcasts a tensor to a new shape.
See [`max.graph.ops.broadcast_to()`](graph/ops.md#max.graph.ops.broadcast_to) for details.
## `buffer_store()` {#max.functional.buffer_store}
> max.functional.buffer\_store(destination, source)
Sets a tensor buffer to new values.
See [`max.graph.ops.buffer_store()`](graph/ops.md#max.graph.ops.buffer_store) for details.
## `buffer_store_slice()` {#max.functional.buffer_store_slice}
> max.functional.buffer\_store\_slice(destination, source, indices)
Sets a slice of a tensor buffer to new values.
See [`max.graph.ops.buffer_store_slice()`](graph/ops.md#max.graph.ops.buffer_store_slice) for details.
## `cast()` {#max.functional.cast}
> max.functional.cast(x, dtype)
Casts a tensor to a different data type.
See [`max.graph.ops.cast()`](graph/ops.md#max.graph.ops.cast) for details.
## `chunk()` {#max.functional.chunk}
> max.functional.chunk(x, chunks, axis=0)
Splits a tensor into chunks along a dimension.
See [`max.graph.ops.chunk()`](graph/ops.md#max.graph.ops.chunk) for details.
## `complex_mul()` {#max.functional.complex_mul}
> max.functional.complex\_mul(lhs, rhs)
Multiplies two complex-valued tensors.
See `max.graph.ops.complex.mul()` for details.
## `concat()` {#max.functional.concat}
> max.functional.concat(original\_vals, axis=0)
Concatenates a list of tensors along an axis.
See [`max.graph.ops.concat()`](graph/ops.md#max.graph.ops.concat) for details.
## `constant()` {#max.functional.constant}
> max.functional.constant(value, dtype=None, device=None)
Creates a constant tensor.
See [`max.graph.ops.constant()`](graph/ops.md#max.graph.ops.constant) for details.
## `constant_external()` {#max.functional.constant_external}
> max.functional.constant\_external(name, type)
Creates a constant tensor from external data.
See [`max.graph.ops.constant_external()`](graph/ops.md#max.graph.ops.constant_external) for details.
**Parameters:**
* name ([str](https://docs.python.org/3/library/stdtypes.html#str))
* type ([TensorType](graph/type.md#max.graph.type.TensorType))
## `cos()` {#max.functional.cos}
> max.functional.cos(x)
Computes the cosine element-wise.
See [`max.graph.ops.cos()`](graph/ops.md#max.graph.ops.cos) for details.
## `cumsum()` {#max.functional.cumsum}
> max.functional.cumsum(x, axis=-1, exclusive=False, reverse=False)
Computes the cumulative sum along an axis.
See [`max.graph.ops.cumsum()`](graph/ops.md#max.graph.ops.cumsum) for details.
## `custom()` {#max.functional.custom}
> max.functional.custom(name, device, values, out\_types, parameters=None, custom\_extensions=None)
Applies a custom operation with optional custom extension loading.
Creates a node to execute a custom graph operation. The custom op should be
registered by annotating a Mojo function with the `@compiler.register`
decorator.
This function extends [`max.graph.ops.custom()`](graph/ops.md#max.graph.ops.custom) with automatic loading
of custom extension libraries, eliminating the need to manually import
kernels before use.
**Example:**
```python
from max import functional as F, Tensor
from max.dtype import DType
from max.driver import CPU
x = Tensor.full([10], 10, dtype=DType.float32, device=CPU())
y = Tensor.ones([10], dtype=DType.float32, device=CPU())
result = F.custom(
"vector_sum",
device=x.device,
values=[x, y],
out_types=[x.type],
custom_extensions="ops.mojopkg"
)[0]
```
**Parameters:**
* name ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The op name provided to `@compiler.register`.
* device ([driver.Device](driver.md#max.driver.Device) | [DeviceRef](graph/ops.md#max.graph.ops.DeviceRef)) – Device that the op is assigned to. This becomes a `target`
parameter to the kernel.
* values (Sequence\[[Value](graph/Value.md#max.graph.Value)\[Any]]) – The op function’s arguments.
* out\_types (Sequence\[[Type](graph/type.md#max.graph.type.Type)\[Any]]) – The list of op function’s return types.
* parameters (Mapping\[[str](https://docs.python.org/3/library/stdtypes.html#str), [bool](https://docs.python.org/3/library/functions.html#bool) | [int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [DType](dtype.md#max.dtype.DType)] | None) – Dictionary of extra parameters expected by the kernel.
* custom\_extensions (CustomExtensionsType | None) – Paths to custom extension libraries (`.mojopkg`
files or Mojo source directories). Extensions are automatically
loaded into the current graph if not already present.
**Returns:**
Symbolic values representing the outputs of the op in the graph.
These correspond 1:1 with the types passed as `out_types`.
:::note See also
[`max.graph.ops.custom()`](graph/ops.md#max.graph.ops.custom): The underlying graph operation.
[`inplace_custom()`](#max.functional.inplace_custom): For in-place custom operations.
:::
## `div()` {#max.functional.div}
> max.functional.div(lhs, rhs)
Divides two tensors element-wise.
See [`max.graph.ops.div()`](graph/ops.md#max.graph.ops.div) for details.
## `erf()` {#max.functional.erf}
> max.functional.erf(x)
Computes the error function element-wise.
See [`max.graph.ops.erf()`](graph/ops.md#max.graph.ops.erf) for details.
## `exp()` {#max.functional.exp}
> max.functional.exp(x)
Computes the exponential element-wise.
See [`max.graph.ops.exp()`](graph/ops.md#max.graph.ops.exp) for details.
## `flatten()` {#max.functional.flatten}
> max.functional.flatten(x, start\_dim=0, end\_dim=-1)
Flattens a tensor.
See [`max.graph.ops.flatten()`](graph/ops.md#max.graph.ops.flatten) for details.
## `floor()` {#max.functional.floor}
> max.functional.floor(x)
Computes the floor element-wise.
See [`max.graph.ops.floor()`](graph/ops.md#max.graph.ops.floor) for details.
## `functional()` {#max.functional.functional}
> max.functional.functional(op)
Decorator that converts a graph operation to support multiple tensor
types.
**Parameters:**
op ([Callable](graph/ops.md#max.graph.ops.Callable)\[\[...], [Any](https://docs.python.org/3/library/typing.html#typing.Any)])
## `gather()` {#max.functional.gather}
> max.functional.gather(input, indices, axis)
Gathers values along an axis specified by indices.
See [`max.graph.ops.gather()`](graph/ops.md#max.graph.ops.gather) for details.
## `gelu()` {#max.functional.gelu}
> max.functional.gelu(x, approximate='none')
Applies the Gaussian Error Linear Unit (GELU) activation.
See [`max.graph.ops.gelu()`](graph/ops.md#max.graph.ops.gelu) for details.
**Parameters:**
* x ([TensorValue](graph/TensorValue.md#max.graph.TensorValue))
* approximate ([str](https://docs.python.org/3/library/stdtypes.html#str))
## `greater()` {#max.functional.greater}
> max.functional.greater(lhs, rhs)
Computes element-wise greater-than comparison.
See [`max.graph.ops.greater()`](graph/ops.md#max.graph.ops.greater) for details.
## `hann_window()` {#max.functional.hann_window}
> max.functional.hann\_window(window\_length, device, periodic=True, dtype=float32)
Creates a Hann window.
See [`max.graph.ops.hann_window()`](graph/ops.md#max.graph.ops.hann_window) for details.
## `inplace_custom()` {#max.functional.inplace_custom}
> max.functional.inplace\_custom(name, device, values, out\_types=None, parameters=None, custom\_extensions=None)
Applies an in-place custom operation with optional custom extension loading.
Creates a node to execute an in-place custom graph operation. The custom op
should be registered by annotating a Mojo function with the
`@compiler.register` decorator.
This function extends [`max.graph.ops.inplace_custom()`](graph/ops.md#max.graph.ops.inplace_custom) with automatic
loading of custom extension libraries, eliminating the need to manually
import kernels before use.
**Example:**
```python
from max import functional as F, Tensor
from max.dtype import DType
from max.driver import CPU
# Create a buffer for in-place modification
data = Tensor.zeros([10], dtype=DType.float32, device=CPU())
# Use in-place custom op with inline extension loading
F.inplace_custom(
"my_inplace_op",
device=data.device,
values=[data],
custom_extensions="ops.mojopkg"
)
```
**Parameters:**
* name ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The op name provided to `@compiler.register`.
* device ([driver.Device](driver.md#max.driver.Device) | [DeviceRef](graph/ops.md#max.graph.ops.DeviceRef)) – Device that the op is assigned to. This becomes a `target`
parameter to the kernel.
* values (Sequence\[[Value](graph/Value.md#max.graph.Value)\[Any]]) – The op function’s arguments. At least one must be a
`BufferValue` or `_OpaqueValue`.
* out\_types (Sequence\[[Type](graph/type.md#max.graph.type.Type)\[Any]] | None) – The list of op function’s return types. Can be None if the
operation has no outputs.
* parameters ([dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [bool](https://docs.python.org/3/library/functions.html#bool) | [int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [DType](dtype.md#max.dtype.DType)] | None) – Dictionary of extra parameters expected by the kernel.
* custom\_extensions (CustomExtensionsType | None) – Paths to custom extension libraries (`.mojopkg`
files or Mojo source directories). Extensions are automatically
loaded into the current graph if not already present.
**Returns:**
Symbolic values representing the outputs of the op in the graph.
:::note See also
[`max.graph.ops.inplace_custom()`](graph/ops.md#max.graph.ops.inplace_custom): The underlying graph operation.
[`custom()`](#max.functional.custom): For non-in-place custom operations.
:::
## `irfft()` {#max.functional.irfft}
> max.functional.irfft(input\_tensor, n=None, axis=-1, normalization=Normalization.BACKWARD, input\_is\_complex=False, buffer\_size\_mb=512)
Computes the inverse real FFT.
See [`max.graph.ops.irfft()`](graph/ops.md#max.graph.ops.irfft) for details.
## `is_inf()` {#max.functional.is_inf}
> max.functional.is\_inf(x)
Checks for infinite values element-wise.
See [`max.graph.ops.is_inf()`](graph/ops.md#max.graph.ops.is_inf) for details.
## `is_nan()` {#max.functional.is_nan}
> max.functional.is\_nan(x)
Checks for NaN values element-wise.
See [`max.graph.ops.is_nan()`](graph/ops.md#max.graph.ops.is_nan) for details.
## `lazy()` {#max.functional.lazy}
> max.functional.lazy()
Context manager for lazy tensor evaluation.
Within this context, tensor operations are recorded but not executed.
Tensors remain unrealized until explicitly awaited via `await tensor.realize`
or until their values are needed (e.g., by calling `.item()`).
This is particularly useful for creating tensors that may never be used:
lazy tensors that go unused never allocate memory or perform operations.
```python
from max import functional as F
from max.tensor import Tensor
from max.nn import Linear
with F.lazy():
model = Linear(2, 3)
print(model) # Lazy weights not initialized
# Executing the model would be fine! The weights would be created
# on first use.
# output = model(Tensor.ones([5, 2]))
# Load pretrained weights, never creating the original random weights
weights = {
"weight": Tensor.zeros([3, 2]),
"bias": Tensor.zeros([3]),
}
model.load_state_dict(weights)
```
## `log()` {#max.functional.log}
> max.functional.log(x)
Computes the natural logarithm element-wise.
See [`max.graph.ops.log()`](graph/ops.md#max.graph.ops.log) for details.
## `logsoftmax()` {#max.functional.logsoftmax}
> max.functional.logsoftmax(value, axis=-1)
Applies the log softmax function.
See [`max.graph.ops.logsoftmax()`](graph/ops.md#max.graph.ops.logsoftmax) for details.
## `masked_scatter()` {#max.functional.masked_scatter}
> max.functional.masked\_scatter(input, mask, updates, out\_dim)
Scatters values according to a mask.
See [`max.graph.ops.masked_scatter()`](graph/ops.md#max.graph.ops.masked_scatter) for details.
## `max()` {#max.functional.max}
> max.functional.max(x, y=None, /, axis=-1)
Returns the maximum values along an axis, or elementwise maximum of two tensors.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) – The input tensor.
* y (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray) | None) – Optional second tensor for elementwise maximum.
* axis ([int](https://docs.python.org/3/library/functions.html#int) | None) – The axis along which to compute the maximum (only for reduction).
If None, computes the maximum across all elements (flattened).
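For intuition, the two modes mirror NumPy's `max`/`maximum` (this is a sketch of the semantics, not MAX code):

```python
import numpy as np

x = np.array([[1.0, 5.0, 3.0],
              [4.0, 2.0, 6.0]], dtype=np.float32)

# Reduction mode: maximum along an axis
np.max(x, axis=-1)          # array([5., 6.], dtype=float32)

# Elementwise mode: maximum of two tensors
y = np.array([[2.0, 0.0, 9.0],
              [1.0, 8.0, 5.0]], dtype=np.float32)
np.maximum(x, y)            # [[2., 5., 9.], [4., 8., 6.]]

# axis=None: maximum over all (flattened) elements
np.max(x)                   # 6.0
```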
## `mean()` {#max.functional.mean}
> max.functional.mean(x, axis=-1)
Computes the mean along specified axes.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) – The input tensor.
* axis ([int](https://docs.python.org/3/library/functions.html#int) | None) – The axis along which to compute the mean. If None,
computes the mean across all elements (flattened).
## `min()` {#max.functional.min}
> max.functional.min(x, y=None, /, axis=-1)
Returns the minimum values along an axis, or elementwise minimum of two tensors.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) – The input tensor.
* y (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray) | None) – Optional second tensor for elementwise minimum.
* axis ([int](https://docs.python.org/3/library/functions.html#int) | None) – The axis along which to compute the minimum (only for reduction).
If None, computes the minimum across all elements (flattened).
## `mod()` {#max.functional.mod}
> max.functional.mod(lhs, rhs)
Computes the modulo operation element-wise.
See [`max.graph.ops.mod()`](graph/ops.md#max.graph.ops.mod) for details.
## `mul()` {#max.functional.mul}
> max.functional.mul(lhs, rhs)
Multiplies two tensors element-wise.
See [`max.graph.ops.mul()`](graph/ops.md#max.graph.ops.mul) for details.
## `negate()` {#max.functional.negate}
> max.functional.negate(x)
Negates a tensor element-wise.
See [`max.graph.ops.negate()`](graph/ops.md#max.graph.ops.negate) for details.
## `nonzero()` {#max.functional.nonzero}
> max.functional.nonzero(x, out\_dim)
Returns the indices of non-zero elements.
See [`max.graph.ops.nonzero()`](graph/ops.md#max.graph.ops.nonzero) for details.
## `outer()` {#max.functional.outer}
> max.functional.outer(lhs, rhs)
Computes the outer product of two vectors.
See [`max.graph.ops.outer()`](graph/ops.md#max.graph.ops.outer) for details.
## `pad()` {#max.functional.pad}
> max.functional.pad(input, paddings, mode='constant', value=0)
Pads a tensor.
See [`max.graph.ops.pad()`](graph/ops.md#max.graph.ops.pad) for details.
## `permute()` {#max.functional.permute}
> max.functional.permute(x, dims)
Permutes the dimensions of a tensor.
See [`max.graph.ops.permute()`](graph/ops.md#max.graph.ops.permute) for details.
## `pow()` {#max.functional.pow}
> max.functional.pow(lhs, rhs)
Raises tensor elements to a power.
See [`max.graph.ops.pow()`](graph/ops.md#max.graph.ops.pow) for details.
## `relu()` {#max.functional.relu}
> max.functional.relu(x)
Applies the ReLU activation function.
See [`max.graph.ops.relu()`](graph/ops.md#max.graph.ops.relu) for details.
## `repeat_interleave()` {#max.functional.repeat_interleave}
> max.functional.repeat\_interleave(x, repeats, axis=None, out\_dim=None)
Repeats elements of a tensor.
See [`max.graph.ops.repeat_interleave()`](graph/ops.md#max.graph.ops.repeat_interleave) for details.
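The semantics are analogous to NumPy's `repeat` (an illustrative sketch, not the MAX kernel):

```python
import numpy as np

x = np.array([[1, 2], [3, 4]])

# Repeat each element twice along axis 0
np.repeat(x, 2, axis=0)   # [[1, 2], [1, 2], [3, 4], [3, 4]]

# axis=None flattens first, then repeats
np.repeat(x, 2)           # [1, 1, 2, 2, 3, 3, 4, 4]
```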
## `reshape()` {#max.functional.reshape}
> max.functional.reshape(x, shape)
Reshapes a tensor to a new shape.
See [`max.graph.ops.reshape()`](graph/ops.md#max.graph.ops.reshape) for details.
## `rsqrt()` {#max.functional.rsqrt}
> max.functional.rsqrt(x)
Computes the reciprocal square root element-wise.
See [`max.graph.ops.rsqrt()`](graph/ops.md#max.graph.ops.rsqrt) for details.
## `scatter()` {#max.functional.scatter}
> max.functional.scatter(input, updates, indices, axis=-1)
Scatters values along an axis.
See [`max.graph.ops.scatter()`](graph/ops.md#max.graph.ops.scatter) for details.
## `sigmoid()` {#max.functional.sigmoid}
> max.functional.sigmoid(x)
Applies the sigmoid activation function.
See [`max.graph.ops.sigmoid()`](graph/ops.md#max.graph.ops.sigmoid) for details.
**Parameters:**
x ([TensorValue](graph/TensorValue.md#max.graph.TensorValue))
## `silu()` {#max.functional.silu}
> max.functional.silu(x)
Applies the SiLU (Swish) activation function.
See [`max.graph.ops.silu()`](graph/ops.md#max.graph.ops.silu) for details.
**Parameters:**
x ([TensorValue](graph/TensorValue.md#max.graph.TensorValue))
## `sin()` {#max.functional.sin}
> max.functional.sin(x)
Computes the sine element-wise.
See [`max.graph.ops.sin()`](graph/ops.md#max.graph.ops.sin) for details.
## `slice_tensor()` {#max.functional.slice_tensor}
> max.functional.slice\_tensor(x, indices)
Slices a tensor along specified dimensions.
See [`max.graph.ops.slice_tensor()`](graph/ops.md#max.graph.ops.slice_tensor) for details.
**Parameters:**
* x ([TensorValue](graph/TensorValue.md#max.graph.TensorValue))
* indices (SliceIndices)
## `softmax()` {#max.functional.softmax}
> max.functional.softmax(value, axis=-1)
Applies the softmax function.
See [`max.graph.ops.softmax()`](graph/ops.md#max.graph.ops.softmax) for details.
## `split()` {#max.functional.split}
> max.functional.split(x, split\_size\_or\_sections, axis=0)
Splits a tensor into multiple tensors along a given dimension.
This function supports two modes, matching PyTorch’s behavior:
* If `split_size_or_sections` is an **int**, splits into chunks of that
size (the last chunk may be smaller if the dimension is not evenly
divisible).
* If `split_size_or_sections` is a **list of ints**, splits into chunks
with exactly those sizes (must sum to the dimension size).
```python
from max import functional as F, Tensor
x = Tensor.ones([10, 4])
# Split into chunks of size 3 (last chunk is size 1)
chunks = F.split(x, 3, axis=0) # shapes: [3,4], [3,4], [3,4], [1,4]
# Split into exact sizes
chunks = F.split(x, [2, 3, 5], axis=0) # shapes: [2,4], [3,4], [5,4]
```
**Parameters:**
* x ([Tensor](tensor.md#max.tensor.Tensor) | [TensorValue](graph/TensorValue.md#max.graph.TensorValue)) – The input tensor to split.
* split\_size\_or\_sections ([int](https://docs.python.org/3/library/functions.html#int) | [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]) – Either an int (chunk size) or a list of ints
(exact sizes for each output tensor).
* axis ([int](https://docs.python.org/3/library/functions.html#int)) – The dimension along which to split. Defaults to 0.
## `sqrt()` {#max.functional.sqrt}
> max.functional.sqrt(x)
Computes the square root element-wise.
See [`max.graph.ops.sqrt()`](graph/ops.md#max.graph.ops.sqrt) for details.
## `squeeze()` {#max.functional.squeeze}
> max.functional.squeeze(x, axis)
Removes dimensions of size 1.
See [`max.graph.ops.squeeze()`](graph/ops.md#max.graph.ops.squeeze) for details.
## `stack()` {#max.functional.stack}
> max.functional.stack(values, axis=0)
Stacks tensors along a new dimension.
See [`max.graph.ops.stack()`](graph/ops.md#max.graph.ops.stack) for details.
## `sub()` {#max.functional.sub}
> max.functional.sub(lhs, rhs)
Subtracts two tensors element-wise.
See [`max.graph.ops.sub()`](graph/ops.md#max.graph.ops.sub) for details.
## `sum()` {#max.functional.sum}
> max.functional.sum(x, axis=-1)
Computes the sum along specified axes.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) – The input tensor.
* axis ([int](https://docs.python.org/3/library/functions.html#int) | None) – The axis along which to compute the sum. If None,
computes the sum across all elements (flattened).
## `tanh()` {#max.functional.tanh}
> max.functional.tanh(x)
Computes the hyperbolic tangent element-wise.
See [`max.graph.ops.tanh()`](graph/ops.md#max.graph.ops.tanh) for details.
## `tile()` {#max.functional.tile}
> max.functional.tile(x, repeats)
Tiles a tensor by repeating it.
See [`max.graph.ops.tile()`](graph/ops.md#max.graph.ops.tile) for details.
## `top_k()` {#max.functional.top_k}
> max.functional.top\_k(input, k, axis=-1)
Returns the k largest elements along an axis.
See [`max.graph.ops.top_k()`](graph/ops.md#max.graph.ops.top_k) for details.
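The values-and-indices behavior of a top-k reduction can be sketched in NumPy (illustrative only; the `top_k` helper below is hypothetical, not part of MAX):

```python
import numpy as np

def top_k(x, k, axis=-1):
    # Sort descending along axis, keep the first k values and their indices
    idx = np.argsort(-x, axis=axis, kind="stable")
    top_idx = np.take(idx, range(k), axis=axis)
    top_val = np.take_along_axis(x, top_idx, axis=axis)
    return top_val, top_idx

x = np.array([[1.0, 9.0, 3.0, 7.0]])
values, indices = top_k(x, k=2)
# values  -> [[9., 7.]]
# indices -> [[1, 3]]
```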
## `transfer_to()` {#max.functional.transfer_to}
> max.functional.transfer\_to(x, device)
Transfers a tensor to a specified device.
See [`max.graph.ops.transfer_to()`](graph/ops.md#max.graph.ops.transfer_to) for details.
## `transpose()` {#max.functional.transpose}
> max.functional.transpose(x, axis\_1, axis\_2)
Transposes a tensor.
See [`max.graph.ops.transpose()`](graph/ops.md#max.graph.ops.transpose) for details.
## `unsqueeze()` {#max.functional.unsqueeze}
> max.functional.unsqueeze(x, axis)
Adds dimensions of size 1.
See [`max.graph.ops.unsqueeze()`](graph/ops.md#max.graph.ops.unsqueeze) for details.
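Together, `squeeze()` and `unsqueeze()` behave like NumPy's `squeeze`/`expand_dims` (a semantic sketch, not MAX code):

```python
import numpy as np

x = np.zeros((3, 4))

# unsqueeze: add a size-1 dimension at the given axis
np.expand_dims(x, axis=0).shape                       # (1, 3, 4)

# squeeze: remove a size-1 dimension, restoring the original shape
np.squeeze(np.expand_dims(x, axis=0), axis=0).shape   # (3, 4)
```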
## `where()` {#max.functional.where}
> max.functional.where(condition, x, y)
Selects elements from two tensors based on a condition.
See [`max.graph.ops.where()`](graph/ops.md#max.graph.ops.where) for details.
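The selection semantics match NumPy's `where` (shown here for intuition only):

```python
import numpy as np

condition = np.array([True, False, True])
x = np.array([1, 2, 3])
y = np.array([10, 20, 30])

# Take from x where condition is True, else from y
result = np.where(condition, x, y)   # [1, 20, 3]
```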
---
## BufferValue
## `BufferValue` {#max.graph.BufferValue}
> class max.graph.BufferValue(value)
Bases: [`Value`](Value.md#max.graph.Value)\[`BufferType`]
Represents a mutable semantic tensor within a Graph.
**Parameters:**
value ([Value](Value.md#max.graph.Value)\[Any] | \_Value\[mo.BufferType] | HasBufferValue)
### `device` {#max.graph.BufferValue.device}
> property device: [DeviceRef](type.md#max.graph.type.DeviceRef)
Returns the device of the BufferValue.
### `dtype` {#max.graph.BufferValue.dtype}
> property dtype: [DType](../dtype.md#max.dtype.DType)
Returns the tensor data type.
### `from_mlir()` {#max.graph.BufferValue.from_mlir}
> classmethod from\_mlir(value)
Creates a [`BufferValue`](#max.graph.BufferValue) from an MLIR buffer value.
**Parameters:**
value (Value\[BufferType]) – The MLIR buffer value to wrap.
**Return type:**
[BufferValue](#max.graph.BufferValue)
### `print()` {#max.graph.BufferValue.print}
> print(label='debug\_buffer')
Prints detailed information about the buffer.
### `rank` {#max.graph.BufferValue.rank}
> property rank: [int](https://docs.python.org/3/library/functions.html#int)
Returns the rank (number of dims) of the buffer.
### `shape` {#max.graph.BufferValue.shape}
> property shape: [Shape](shape.md#max.graph.shape.Shape)
Returns the shape of the BufferValue.
### `type` {#max.graph.BufferValue.type}
> property type: [BufferType](type.md#max.graph.type.BufferType)
Returns the type of the [`BufferValue`](#max.graph.BufferValue) as a `BufferType`.
---
## Graph
## `Graph` {#max.graph.Graph}
> class max.graph.Graph(name, forward=None, input\_types=(), path=None, \*args, custom\_extensions=\[], kernel\_library=None, module=None, \*\*kwargs)
Represents a single MAX graph.
A Graph is a callable routine in MAX Engine. Like functions, graphs have a
name and signature. Unlike a function, which follows an imperative
programming model, a Graph follows a dataflow programming model, using
lazily-executed, parallel operations instead of sequential instructions.
When you instantiate a graph, you must specify the input shapes as one or
more `TensorType` values. Then, build a sequence of ops and set the
graph output with [`output()`](#max.graph.Graph.output). For example:
```python
from dataclasses import dataclass
import numpy as np
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, TensorValue, ops
@dataclass
class Linear:
weight: np.ndarray
bias: np.ndarray
def __call__(self, x: TensorValue) -> TensorValue:
weight_tensor = ops.constant(self.weight, dtype=DType.float32, device=DeviceRef.CPU())
bias_tensor = ops.constant(self.bias, dtype=DType.float32, device=DeviceRef.CPU())
return ops.matmul(x, weight_tensor) + bias_tensor
linear_graph = Graph(
"linear",
Linear(np.ones((2, 2)), np.ones((2,))),
input_types=[TensorType(DType.float32, (2,))]
)
```
You can’t call a Graph directly from Python. You must compile it and
execute it with MAX Engine. For more detail, see the tutorial about how to
[build a graph with MAX
Graph](/max/tutorials/get-started-with-max-graph-in-python).
When creating a graph, a global sequence of chains is initialized and stored
in Graph.\_current\_chain. Every side-effecting op, e.g. buffer\_load,
store\_buffer, load\_slice\_buffer, store\_slice\_buffer, will use the current
chain to perform the op and update Graph.\_current\_chain with a new
chain. Currently, the input/output chains for mutable ops can be used at
most once. The goal of this design choice is to prevent data races.
**Parameters:**
* name ([str](https://docs.python.org/3/library/stdtypes.html#str)) – A name for the graph.
* forward ([Callable](ops.md#max.graph.ops.Callable)\[..., None | [Value](Value.md#max.graph.Value)\[Any] | Iterable\[[Value](Value.md#max.graph.Value)\[Any]]] | None) – The sequence of graph ops for the forward pass (inference).
* input\_types (Iterable\[[Type](type.md#max.graph.type.Type)\[Any]]) – The data type(s) for the input tensor(s).
* path (Path | None) – The path to a saved graph (internal use only).
* custom\_extensions (Iterable\[Path]) – The extensions to load for the model. Supports paths
to `.mojopkg` or `.mojo` sources with custom ops.
* kernel\_library ([KernelLibrary](KernelLibrary.md#max.graph.KernelLibrary) | None) – Optional pre-built kernel library to use. Defaults to
`None` (a new library is created from `custom_extensions` if
needed).
* module (mlir.Module | None) – Optional existing MLIR module (internal use only). Defaults to
`None`.
### `add_subgraph()` {#max.graph.Graph.add_subgraph}
> add\_subgraph(name, forward=None, input\_types=(), path=None, custom\_extensions=\[], devices=\[])
Creates and adds a subgraph to the current graph.
Creates a new [`Graph`](#max.graph.Graph) instance configured as a subgraph of the current
graph. The subgraph inherits the parent graph’s module and symbolic
parameters. A chain type is automatically appended to the input
types to enable proper operation sequencing within the subgraph.
The created subgraph is marked with special MLIR attributes to identify it
as a subgraph and is registered in the parent graph’s subgraph registry.
**Parameters:**
* name ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The name identifier for the subgraph.
* forward ([Callable](ops.md#max.graph.ops.Callable)\[\[...], None | [Value](Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[Value](Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]] | None) – The optional callable that defines the sequence of operations
for the subgraph’s forward pass. If provided, the subgraph will be
built immediately using this callable.
* input\_types ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[Type](type.md#max.graph.type.Type)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The data types for the subgraph’s input tensors. A chain
type will be automatically added to these input types.
* path (Path | None) – The optional path to a saved subgraph definition to load from
disk instead of creating a new one.
* custom\_extensions ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[Path]) – The list of paths to custom operation libraries
to load for the subgraph. Supports `.mojopkg` files and Mojo
source directories.
* devices ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[DeviceRef](type.md#max.graph.type.DeviceRef)]) – The list of devices this subgraph is meant to use.
**Return type:**
[Graph](#max.graph.Graph)
### `add_weight()` {#max.graph.Graph.add_weight}
> add\_weight(weight, force\_initial\_weight\_on\_host=True)
Adds a weight to the graph.
If the weight is in the graph already, return the existing value.
**Parameters:**
* weight ([Weight](Weight.md#max.graph.Weight)) – The weight to add to the graph.
* force\_initial\_weight\_on\_host ([bool](https://docs.python.org/3/library/functions.html#bool)) – If true, forces weights
to initially be allocated on host before being moved to
the indicated device. This is needed as a stopgap
until we have a more fleshed-out ownership model of
external constants.
**Returns:**
A [`TensorValue`](TensorValue.md#max.graph.TensorValue) that contains this weight.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If a weight with the same name already exists in the graph.
### `always_ready_chain` {#max.graph.Graph.always_ready_chain}
> property always\_ready\_chain: \_ChainValue
A graph-global, immutable chain that is always ready.
Created once per graph and never advanced/merged by the graph itself.
Use it for operations that are safe to schedule without threading
per-device ordering (e.g., host→device transfers for staging).
### `current` {#max.graph.Graph.current}
> current
### `device_chains` {#max.graph.Graph.device_chains}
> device\_chains: \_DeviceChainMap
### `inputs` {#max.graph.Graph.inputs}
> property inputs: [Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[Value](Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]
The input values of the graph.
### `kernel_libraries_paths` {#max.graph.Graph.kernel_libraries_paths}
> property kernel\_libraries\_paths: [list](https://docs.python.org/3/library/stdtypes.html#list)\[Path]
Returns the list of extra kernel libraries paths for the custom ops.
### `output()` {#max.graph.Graph.output}
> output(\*outputs)
Sets the output nodes of the [`Graph`](#max.graph.Graph).
### `output_types` {#max.graph.Graph.output_types}
> property output\_types: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[Type](type.md#max.graph.type.Type)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]
View of the types of the graph output terminator.
---
## KernelLibrary
## `KernelLibrary` {#max.graph.KernelLibrary}
> class max.graph.KernelLibrary(paths=())
Manages custom kernel libraries and operations for a graph.
A kernel library provides access to custom operations and kernels that can
be loaded from various sources including Mojo binary packages (`.mojopkg`)
and Mojo source directories. The library handles verification and registration
of custom operations within the MLIR context.
**Parameters:**
paths (Iterable\[Path])
### `add_path()` {#max.graph.KernelLibrary.add_path}
> add\_path(path)
Adds a kernel library path to the analysis.
**Parameters:**
path (Path) – The `Path` to the kernel library to be added to the
current analysis.
**Return type:**
None
### `library_paths()` {#max.graph.KernelLibrary.library_paths}
> library\_paths()
Returns the list of kernel library paths.
**Returns:**
A list of `Path` objects representing the currently loaded
kernel library paths.
### `load_paths()` {#max.graph.KernelLibrary.load_paths}
> load\_paths(custom\_extensions)
Loads custom operations from provided library paths.
Performs “smart” library loading logic for custom operation
libraries in additional formats. The loading logic supports the
following formats:
* Compiled Mojo binary packages with `.mojopkg` extension
* Mojo source directory with custom operations
The loaded libraries are added to the current kernel library.
**Parameters:**
custom\_extensions ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[Path]) – The file paths to the custom operation libraries.
**Return type:**
None
### `verify_custom_op()` {#max.graph.KernelLibrary.verify_custom_op}
> verify\_custom\_op(custom\_op)
Verifies that a custom operation is valid within the current context.
**Parameters:**
custom\_op (Operation) – The `mlir.Operation` to be verified against the
current kernel library analysis.
**Return type:**
None
---
## TensorValue
## `TensorValue` {#max.graph.TensorValue}
> class max.graph.TensorValue(value)
Bases: [`Value`](Value.md#max.graph.Value)\[`TensorType`]
Represents a value semantic tensor within a [`Graph`](Graph.md#max.graph.Graph). It provides
various methods and properties to manipulate and query tensor attributes
such as [`shape`](shape.md#module-max.graph.shape), data type ([`dtype`](#max.graph.TensorValue.dtype)), device placement ([`device`](#max.graph.TensorValue.device)), and more.
The following example demonstrates how to create and manipulate tensor values in a graph:
```python
import numpy as np
from max.dtype import DType
from max.graph import DeviceRef, Graph, ops
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)
# Create a Graph context to work with tensors
with Graph("tensor_demo") as graph:
# Create a constant tensor from the matrix
tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())
# Access tensor properties
print(f"Shape: {tensor.shape}") # Output: [2, 2]
print(f"Data type: {tensor.dtype}") # Output: DType.float32
# Perform operations on the tensor
transposed = tensor.T
doubled = tensor * 2
print(f"Original shape: {tensor.shape}") # Output: [2, 2]
print(f"Transposed shape: {transposed.shape}") # Output: [2, 2]
```
**Parameters:**
value (TensorValueLike)
### `T` {#max.graph.TensorValue.T}
> property T: [TensorValue](#max.graph.TensorValue)
Returns the transposed tensor.
[`T`](#max.graph.TensorValue.T) is the shorthand notation for transposing.
For more information, see [`transpose()`](#max.graph.TensorValue.transpose).
**Returns:**
A new [`TensorValue`](#max.graph.TensorValue) with swapped dimensions.
### `argmax()` {#max.graph.TensorValue.argmax}
> argmax(axis=-1)
Reduces the tensor using an argmax operation along `axis`.
When the result is ambiguous (i.e., there are multiple maxima),
one index is selected arbitrarily.
```python
from max.dtype import DType
from max.graph import Graph, TensorType, DeviceRef
# Define a 2x3 float32 input tensor for the graph
input_type = TensorType(DType.float32, (2, 3), device=DeviceRef.CPU())
with Graph("argmax_demo", input_types=[input_type]) as graph:
x = graph.inputs[0].tensor
# Argmax along axis 1 (last dimension of each row)
indices = x.argmax(axis=1)
print(f"Input shape: {x.shape}") # [2, 3]
print(f"Argmax shape: {indices.shape}") # [2, 1]
```
**Parameters:**
axis ([int](https://docs.python.org/3/library/functions.html#int)) – The axis along which to compute the reduction. If negative,
indexes from the last dimension (e.g., `-1` is the last dimension).
**Returns:**
A [`TensorValue`](#max.graph.TensorValue) of dtype `DType.int64` with the same rank as the input,
and the same shape except along `axis`, which will have size 1.
**Return type:**
[TensorValue](#max.graph.TensorValue)
### `broadcast_to()` {#max.graph.TensorValue.broadcast_to}
> broadcast\_to(shape)
Broadcasts the tensor to a new shape.
The following example demonstrates how to broadcast a tensor to a larger shape:
```python
import numpy as np
from max.dtype import DType
from max.graph import DeviceRef, Graph, ops
# Create a 2x2 matrix
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)
# Create a Graph context to work with tensors
with Graph("broadcast_to_demo") as graph:
# Create a constant tensor from the matrix
tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())
# Broadcast tensor to a 3x2x2 tensor (add a new dimension of size 3)
broadcasted_tensor = tensor.broadcast_to((3, 2, 2))
print(f"Original shape: {tensor.shape}") # Output: [2, 2]
print(f"Broadcasted shape: {broadcasted_tensor.shape}") # Output: [3, 2, 2]
```
**Parameters:**
shape ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – An iterable of integers or symbolic dimensions.
**Returns:**
A new [`TensorValue`](#max.graph.TensorValue) with the broadcasted shape.
**Return type:**
[TensorValue](#max.graph.TensorValue)
### `cast()` {#max.graph.TensorValue.cast}
> cast(dtype)
Casts a symbolic tensor to a different data type.
The following example demonstrates how to cast a tensor from one data type to another:
```python
import numpy as np
from max.dtype import DType
from max.graph import DeviceRef, Graph, ops
# Create a matrix with float32 values
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)
# Create a Graph context to work with tensors
with Graph("cast_demo") as graph:
# Create a constant tensor from the matrix
tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())
# Cast tensor to integer type
casted_tensor = tensor.cast(DType.int32)
print(f"Original dtype: {tensor.dtype}") # Output: DType.float32
print(f"Casted dtype: {casted_tensor.dtype}") # Output: DType.int32
```
**Parameters:**
dtype ([DType](../dtype.md#max.dtype.DType)) – The target data type (e.g., `DType.int32`, `DType.float64`).
**Returns:**
A new [`TensorValue`](#max.graph.TensorValue) with the casted data type.
**Return type:**
[TensorValue](#max.graph.TensorValue)
### `device` {#max.graph.TensorValue.device}
> property device: [DeviceRef](type.md#max.graph.type.DeviceRef)
Returns the device of the TensorValue.
### `dtype` {#max.graph.TensorValue.dtype}
> property dtype: [DType](../dtype.md#max.dtype.DType)
Returns the tensor data type.
The following example demonstrates how to access the data type of a tensor:
```python
import numpy as np
from max.dtype import DType
from max.graph import DeviceRef, Graph, ops
# Create a matrix with float32 values
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)
# Create a Graph context to work with tensors
with Graph("dtype_demo") as graph:
# Create a constant tensor from the matrix
tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())
# Access tensor data type
print(f"Data type: {tensor.dtype}") # Output: DType.float32
```
### `flatten()` {#max.graph.TensorValue.flatten}
> flatten(start\_dim=0, end\_dim=-1)
Flattens the specified dims of a symbolic tensor.
The number and order of the elements in the tensor is unchanged.
All dimensions from `start_dim` to `end_dim` (inclusive) are merged into a single output dim.
The following example demonstrates how to flatten a multi-dimensional tensor:
```python
import numpy as np
from max.dtype import DType
from max.graph import Graph, ops, DeviceRef
# Create a 2x2 matrix
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)
# Create a Graph context to work with tensors
with Graph("flatten_demo") as graph:
# Create a constant tensor from the matrix
tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())
# Flatten the tensor to a 1D array
flattened_tensor = tensor.flatten()
print(f"Original shape: {tensor.shape}") # Output: [2, 2]
print(f"Flattened shape: {flattened_tensor.shape}") # Output: [4]
```
**Parameters:**
* start\_dim ([int](https://docs.python.org/3/library/functions.html#int)) – The starting dimension to flatten. Defaults to `0`.
* end\_dim ([int](https://docs.python.org/3/library/functions.html#int)) – The ending dimension to flatten. Defaults to `-1`.
**Returns:**
A new [`TensorValue`](#max.graph.TensorValue) with the flattened dimensions.
**Return type:**
[TensorValue](#max.graph.TensorValue)
### `from_mlir()` {#max.graph.TensorValue.from_mlir}
> classmethod from\_mlir(value)
Creates a [`TensorValue`](#max.graph.TensorValue) from an MLIR tensor value.
**Parameters:**
value (Value\[TensorType]) – The MLIR tensor value to wrap.
**Return type:**
[TensorValue](#max.graph.TensorValue)
### `max()` {#max.graph.TensorValue.max}
> max(axis=-1)
Reduces the tensor using a max operation along `axis`.
```python
from max.dtype import DType
from max.graph import Graph, TensorType, DeviceRef
# Define a 2x3 float32 input tensor for the graph
input_type = TensorType(DType.float32, (2, 3), device=DeviceRef.CPU())
with Graph("max_demo", input_types=[input_type]) as graph:
x = graph.inputs[0].tensor
# Max along axis 1 (last dimension of each row)
m = x.max(axis=1)
print(f"Input shape: {x.shape}") # [2, 3]
print(f"Max shape: {m.shape}") # [2, 1]
```
**Parameters:**
axis ([int](https://docs.python.org/3/library/functions.html#int)) – The axis along which to compute the reduction. If negative,
indexes from the last dimension (e.g., `-1` is the last dimension).
**Returns:**
A [`TensorValue`](#max.graph.TensorValue) with the same rank as the input and the same
shape except along `axis`, which will have size 1.
**Return type:**
[TensorValue](#max.graph.TensorValue)
### `mean()` {#max.graph.TensorValue.mean}
> mean(axis=-1)
Reduces the tensor using a mean operation along `axis`.
```python
from max.dtype import DType
from max.graph import Graph, TensorType, DeviceRef
# Define a 2x3 float32 input tensor for the graph
input_type = TensorType(DType.float32, (2, 3), device=DeviceRef.CPU())
with Graph("mean_demo", input_types=[input_type]) as graph:
x = graph.inputs[0].tensor
# Mean along axis 1 (last dimension of each row)
mu = x.mean(axis=1)
print(f"Input shape: {x.shape}") # [2, 3]
print(f"Mean shape: {mu.shape}") # [2, 1]
```
**Parameters:**
axis ([int](https://docs.python.org/3/library/functions.html#int)) – The axis along which to compute the reduction. If negative,
indexes from the last dimension (e.g., `-1` is the last dimension).
**Returns:**
A [`TensorValue`](#max.graph.TensorValue) with the same rank as the input and the same
shape except along `axis`, which will have size 1.
**Return type:**
[TensorValue](#max.graph.TensorValue)
### `min()` {#max.graph.TensorValue.min}
> min(axis=-1)
Reduces the tensor using a min operation along `axis`.
```python
from max.dtype import DType
from max.graph import Graph, TensorType, DeviceRef
# Define a 2x3 float32 input tensor for the graph
input_type = TensorType(DType.float32, (2, 3), device=DeviceRef.CPU())
with Graph("min_demo", input_types=[input_type]) as graph:
x = graph.inputs[0].tensor
# Min along axis 1 (last dimension of each row)
mn = x.min(axis=1)
print(f"Input shape: {x.shape}") # [2, 3]
print(f"Min shape: {mn.shape}") # [2, 1]
```
**Parameters:**
axis ([int](https://docs.python.org/3/library/functions.html#int)) – The axis along which to compute the reduction. If negative,
indexes from the last dimension (e.g., `-1` is the last dimension).
**Returns:**
A [`TensorValue`](#max.graph.TensorValue) with the same rank as the input and the same
shape except along `axis`, which will have size 1.
**Return type:**
[TensorValue](#max.graph.TensorValue)
### `permute()` {#max.graph.TensorValue.permute}
> permute(dims)
Permutes the tensor’s dimensions based on provided indices.
**Parameters:**
dims ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]) – A list of integers specifying the new order of dimensions.
**Returns:**
A new [`TensorValue`](#max.graph.TensorValue) with permuted dimensions.
**Return type:**
[TensorValue](#max.graph.TensorValue)
### `print()` {#max.graph.TensorValue.print}
> print(label='debug\_tensor')
Prints detailed information about the tensor.
**Parameters:**
label ([str](https://docs.python.org/3/library/stdtypes.html#str)) – A string label for the printed output. Defaults to `debug_tensor`.
**Return type:**
None
### `rank` {#max.graph.TensorValue.rank}
> property rank: [int](https://docs.python.org/3/library/functions.html#int)
Returns the rank (number of dims) of the buffer.
The following example demonstrates how to access the rank of a tensor:
```python
import numpy as np
from max.dtype import DType
from max.graph import Graph, ops, DeviceRef
# Create a 2x2 matrix (2-dimensional array)
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)
# Create a Graph context to work with tensors
with Graph("rank_demo") as graph:
# Create a constant tensor from the matrix
tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())
# Access tensor rank (number of dimensions)
print(f"Rank: {tensor.rank}") # Output: 2
```
### `rebind()` {#max.graph.TensorValue.rebind}
> rebind(shape, message='')
Rebinds the tensor to a new shape with error handling.
**Parameters:**
* shape ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The new shape as an iterable of integers or symbolic dimensions.
* message ([str](https://docs.python.org/3/library/stdtypes.html#str)) – (optional) A message for logging or debugging.
**Returns:**
A new [`TensorValue`](#max.graph.TensorValue) with the updated shape.
**Return type:**
[TensorValue](#max.graph.TensorValue)
### `reshape()` {#max.graph.TensorValue.reshape}
> reshape(shape)
Creates a new tensor with the same data but reshaped.
The following example demonstrates how to reshape a tensor to change its dimensions:
```python
import numpy as np
from max.dtype import DType
from max.graph import Graph, ops, DeviceRef
# Create a 2x2 matrix
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)
# Create a Graph context to work with tensors
with Graph("reshape_demo") as graph:
# Create a constant tensor from the matrix
tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())
# Reshape tensor to a 1x4 matrix
reshaped_tensor = tensor.reshape((1, 4))
print(f"Original shape: {tensor.shape}") # Output: [2, 2]
print(f"Reshaped shape: {reshaped_tensor.shape}") # Output: [1, 4]
```
**Parameters:**
shape ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The new shape as an iterable of integers or symbolic dimensions.
**Returns:**
A new [`TensorValue`](#max.graph.TensorValue) with the reshaped dimensions.
**Return type:**
[TensorValue](#max.graph.TensorValue)
### `shape` {#max.graph.TensorValue.shape}
> property shape: [Shape](shape.md#max.graph.shape.Shape)
Returns the shape of the [`TensorValue`](#max.graph.TensorValue).
The following example demonstrates how to access the shape of a tensor:
```python
import numpy as np
from max.dtype import DType
from max.graph import Graph, ops, DeviceRef
# Create a 2x2 matrix
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)
# Create a Graph context to work with tensors
with Graph("shape_demo") as graph:
# Create a constant tensor from the matrix
tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())
# Access tensor shape
print(f"Shape: {tensor.shape}") # Shape: [Dim(2), Dim(2)]
```
### `stdev()` {#max.graph.TensorValue.stdev}
> stdev(axis=-1)
Reduces the tensor using a standard deviation operation along `axis`.
The standard deviation is computed as the square root of the population
variance along the specified axis.
```python
from max.dtype import DType
from max.graph import Graph, TensorType, DeviceRef
# Define a 2x3 float32 input tensor for the graph
input_type = TensorType(DType.float32, (2, 3), device=DeviceRef.CPU())
with Graph("stdev_demo", input_types=[input_type]) as graph:
x = graph.inputs[0].tensor
# Standard deviation along axis 1 (last dimension of each row)
sd = x.stdev(axis=1)
print(f"Input shape: {x.shape}") # [2, 3]
print(f"Stdev shape: {sd.shape}") # [2, 1]
```
**Parameters:**
axis ([int](https://docs.python.org/3/library/functions.html#int)) – The axis along which to compute the reduction. If negative,
indexes from the last dimension (e.g., `-1` is the last dimension).
**Returns:**
A [`TensorValue`](#max.graph.TensorValue) with the same rank as the input and the same
shape except along `axis`, which will have size 1.
**Return type:**
[TensorValue](#max.graph.TensorValue)
### `to()` {#max.graph.TensorValue.to}
> to(device)
Transfers the tensor to a specified device without mutation.
The following example demonstrates how to move a tensor from one device to another:
```python
import numpy as np
from max.dtype import DType
from max.graph import Graph, ops, DeviceRef
# Create a 2x2 matrix
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)
with Graph("to_device_example") as graph:
# Create a tensor on the default device
tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())
# Move the tensor to a GPU device
gpu_tensor = tensor.to(DeviceRef.GPU())
print(f"Original device: {tensor.device}") # Output depends on default device
print(f"New device: {gpu_tensor.device}") # Output: gpu:0
```
**Parameters:**
device ([DeviceRef](type.md#max.graph.type.DeviceRef)) – A `DeviceRef` object specifying the target device.
**Returns:**
A new [`TensorValue`](#max.graph.TensorValue) on the specified device.
**Return type:**
[TensorValue](#max.graph.TensorValue)
### `transpose()` {#max.graph.TensorValue.transpose}
> transpose(dim\_1, dim\_2)
Swaps two dimensions of the tensor.
The following example demonstrates how to transpose a tensor by swapping its dimensions:
```python
import numpy as np
from max.dtype import DType
from max.graph import Graph, ops, DeviceRef
# Create a 2x3 matrix
matrix = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
with Graph("transpose_demo") as graph:
tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())
# Transpose the tensor (swap dimensions 0 and 1)
transposed_tensor = tensor.transpose(dim_1=0, dim_2=1)
print(f"Original shape: {tensor.shape}") # Output: [2, 3]
print(f"Transposed shape: {transposed_tensor.shape}") # Output: [3, 2]
```
**Parameters:**
* dim\_1 ([int](https://docs.python.org/3/library/functions.html#int)) – The first dimension to swap.
* dim\_2 ([int](https://docs.python.org/3/library/functions.html#int)) – The second dimension to swap.
**Returns:**
A new [`TensorValue`](#max.graph.TensorValue) with swapped dimensions.
**Return type:**
[TensorValue](#max.graph.TensorValue)
### `type` {#max.graph.TensorValue.type}
> property type: [TensorType](type.md#max.graph.type.TensorType)
Returns the type of the [`TensorValue`](#max.graph.TensorValue) as a `TensorType`.
### `var()` {#max.graph.TensorValue.var}
> var(axis=-1)
Reduces the tensor using a variance operation along `axis`.
The variance is computed as the mean of squared deviations from the mean
(population variance, i.e., without Bessel’s correction) along the specified axis.
```python
from max.dtype import DType
from max.graph import Graph, TensorType, DeviceRef
# Define a 2x3 float32 input tensor for the graph
input_type = TensorType(DType.float32, (2, 3), device=DeviceRef.CPU())
with Graph("var_demo", input_types=[input_type]) as graph:
x = graph.inputs[0].tensor
# Variance along axis 1 (last dimension of each row)
vr = x.var(axis=1)
print(f"Input shape: {x.shape}") # [2, 3]
print(f"Var shape: {vr.shape}") # [2, 1]
```
**Parameters:**
axis ([int](https://docs.python.org/3/library/functions.html#int)) – The axis along which to compute the reduction. If negative,
indexes from the last dimension (e.g., `-1` is the last dimension).
**Returns:**
A [`TensorValue`](#max.graph.TensorValue) with the same rank as the input and the same
shape except along `axis`, which will have size 1.
**Return type:**
[TensorValue](#max.graph.TensorValue)
---
## Value
## `Value` {#max.graph.Value}
> class max.graph.Value
Represents a symbolic value within a Graph.
A Value can represent the output of a node, the arguments of a
Graph (as seen from within its body), and more generally any symbolic
value available within the Graph. Other nodes receive Values
as inputs to form a computation graph.
A Value may also refer to an existing input or output of a node,
and you can change it, for example by swapping in a new Value.
Conceptually, think of a Value as an edge in the dataflow graph,
with the other end being the user of that value.
The following example shows how to work with Values in a graph to create a simple computation:
```python
from max.graph import Graph, ops, Value, DeviceRef
from max.dtype import DType
import numpy as np
with Graph("value_example") as graph:
# Create input values
a = ops.constant(np.array([1, 2, 3]), dtype=DType.float32, device=DeviceRef.CPU())
b = ops.constant(np.array([4, 5, 6]), dtype=DType.float32, device=DeviceRef.CPU())
# Use values to perform operations
c = a + b # c is a Value representing the addition
# Demonstrate that the result is a Value
print(f"Type of c: {type(c)}")
print(f"Is c a Value? {isinstance(c, Value)}")
```
Similar to a regular variable, a Value has a data type.
### `buffer` {#max.graph.Value.buffer}
> property buffer: [BufferValue](BufferValue.md#max.graph.BufferValue)
Returns the Value as a [`BufferValue`](BufferValue.md#max.graph.BufferValue).
Raises an exception if the Value is not a BufferValue.
### `from_mlir()` {#max.graph.Value.from_mlir}
> classmethod from\_mlir(value)
Creates a [`Value`](#max.graph.Value) from an MLIR value.
**Parameters:**
value (Value\[MlirType]) – The MLIR value to wrap.
### `opaque` {#max.graph.Value.opaque}
> property opaque: \_OpaqueValue
Returns the Value as an `_OpaqueValue`.
Raises an exception if the Value is not a \_OpaqueValue.
### `tensor` {#max.graph.Value.tensor}
> property tensor: [TensorValue](TensorValue.md#max.graph.TensorValue)
Returns the Value as a [`TensorValue`](TensorValue.md#max.graph.TensorValue).
Raises an exception if the Value is not a TensorValue.
### `to_mlir()` {#max.graph.Value.to_mlir}
> to\_mlir()
Converts the [`Value`](#max.graph.Value) to an MLIR value.
**Return type:**
Value\[MlirType]
### `type` {#max.graph.Value.type}
> property type: [Type](type.md#max.graph.type.Type)\[MlirType]
Returns the type of the [`Value`](#max.graph.Value) as a `Type`.
---
## Weight
## `Weight` {#max.graph.Weight}
> class max.graph.Weight(\*args, \*\*kwargs)
Bases: [`TensorValue`](TensorValue.md#max.graph.TensorValue)
Represents a value in a Graph that can be loaded at a later time.
Weights can be initialized outside of a Graph and are lazily added to
the parent graph when used. If there is no parent graph when a weight is
used, an error will be raised.
### `align` {#max.graph.Weight.align}
> align: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None)
### `device` {#max.graph.Weight.device}
> property device: [DeviceRef](type.md#max.graph.type.DeviceRef)
The device where the weight resides.
### `dtype` {#max.graph.Weight.dtype}
> property dtype: [DType](../dtype.md#max.dtype.DType)
The data type of the weight.
### `original_dtype_and_shape` {#max.graph.Weight.original_dtype_and_shape}
> property original\_dtype\_and\_shape: [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[DType](../dtype.md#max.dtype.DType), [Shape](shape.md#max.graph.shape.Shape)]
The original dtype and shape of this weight.
This property stores the weight’s original dtype and shape when a
quantization encoding forces the weight to be loaded as uint8.
### `quantization_encoding` {#max.graph.Weight.quantization_encoding}
> quantization\_encoding: [QuantizationEncoding](quantization.md#max.graph.quantization.QuantizationEncoding) | [None](https://docs.python.org/3/library/constants.html#None)
### `shape` {#max.graph.Weight.shape}
> property shape: [Shape](shape.md#max.graph.shape.Shape)
The shape of the weight.
For sharded weights, returns the shape of the shard. Otherwise,
returns the original weight shape.
### `shard()` {#max.graph.Weight.shard}
> shard(devices)
Creates sharded views of this Weight across multiple devices.
This Weight must have a `sharding_strategy` defined. The returned shards
are also Weight objects, but they cannot be sharded further.
**Parameters:**
devices ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[DeviceRef](type.md#max.graph.type.DeviceRef)]) – Iterable of devices to place the shards on.
### `shard_idx` {#max.graph.Weight.shard_idx}
> shard\_idx: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None)
### `sharding_strategy` {#max.graph.Weight.sharding_strategy}
> property sharding\_strategy: ShardingStrategy | [None](https://docs.python.org/3/library/constants.html#None)
Gets the weight sharding strategy.
---
## dim
Library for graph dimension types.
## `AlgebraicDim` {#max.graph.dim.AlgebraicDim}
> class max.graph.dim.AlgebraicDim(value)
An algebraic tensor dimension to enable expressions over symbolic
dimensions.
That is, any expression over a symbolic dimension returns `AlgebraicDim`.
Furthermore, algebraic dimensions automatically simplify into a canonical
form.
The following example demonstrates how to create and use algebraic dimensions with symbolic values:
```python
from max.graph import AlgebraicDim, Dim
isinstance(Dim("batch") * 5, AlgebraicDim) # Returns True
print(Dim("batch") * 5) # Outputs: batch * 5
-Dim("x") - 4 == -(Dim("x") + 4) # Returns True
```
### `attr` {#max.graph.dim.AlgebraicDim.attr}
> attr: ParamOperatorAttr
### `from_mlir()` {#max.graph.dim.AlgebraicDim.from_mlir}
> static from\_mlir(attr)
Constructs a dimension from an `mlir.Attribute`.
**Parameters:**
attr (TypedAttr) – The MLIR attribute to parse into a dimension.
**Returns:**
The dimension represented by the MLIR Attr value.
**Return type:**
[Dim](#max.graph.dim.Dim)
### `parameters` {#max.graph.dim.AlgebraicDim.parameters}
> property parameters: [Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[SymbolicDim](#max.graph.dim.SymbolicDim)]
Lists the symbolic dimension names on which this dim depends.
### `to_mlir()` {#max.graph.dim.AlgebraicDim.to_mlir}
> to\_mlir()
Creates an mlir.Attribute representing this dimension.
This is used internally when constructing tensor MLIR types.
**Returns:**
An mlir.Attribute in the context representing the dimension.
**Return type:**
ParamOperatorAttr
## `Dim` {#max.graph.dim.Dim}
> class max.graph.dim.Dim(value)
A tensor dimension.
Tensor dimensions can be one of three types:
* **Static**: Known size
* **Symbolic**: Unknown size but named
* **Algebraic**: Unknown size defined by an algebraic expression
In most cases, you don’t need to work with a `Dim` directly.
Instead, use conversion constructors:
```python
from max.graph import Dim, TensorType, DeviceRef
from max.dtype import DType
tensor_type = TensorType(DType.int64, ("batch", 10), device=DeviceRef.CPU())
```
This creates a tensor type with two dimensions:
* A symbolic “batch” dimension
* A static dimension of size 10
For explicit dimension construction, use the following helpers:
```python
from max.graph import Dim, SymbolicDim, StaticDim, AlgebraicDim
some_dims = [
SymbolicDim("batch"),
StaticDim(5),
AlgebraicDim(Dim("batch") + 1),
]
```
Constraining tensor dimensions is one important way to improve model
performance. If tensors have unknown dimensions, we can’t optimize them
as aggressively. Symbolic tensors allow the compiler to learn constraints
on a specific dimension (e.g., if two inputs have the same batch dimension),
but static dims are the easiest to optimize and therefore the easiest to
create and work with.
**Parameters:**
value (DimLike)
### `from_mlir()` {#max.graph.dim.Dim.from_mlir}
> static from\_mlir(attr)
Constructs a dimension from an `mlir.Attribute`.
**Parameters:**
attr (TypedAttr) – The MLIR attribute to parse into a dimension.
**Returns:**
The dimension represented by the MLIR Attr value.
**Return type:**
[Dim](#max.graph.dim.Dim)
### `parameters` {#max.graph.dim.Dim.parameters}
> property parameters: [Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[SymbolicDim](#max.graph.dim.SymbolicDim)]
Lists the symbolic dimension names on which this dim depends.
### `to_mlir()` {#max.graph.dim.Dim.to_mlir}
> to\_mlir()
Creates an `mlir.Attribute` representing this dimension.
This is used internally when constructing tensor MLIR types.
**Returns:**
An `mlir.Attribute` in the context representing the dimension.
**Return type:**
TypedAttr
## `StaticDim` {#max.graph.dim.StaticDim}
> class max.graph.dim.StaticDim(value)
A static tensor dimension.
Static tensor dimensions will always have exactly the same value,
and are key to good model performance.
The following example shows how static dimensions can be created implicitly:
```python
from max.graph import TensorType, DeviceRef
from max.dtype import DType
tensor = TensorType(DType.int64, (4, 5), device=DeviceRef.CPU())
```
**Parameters:**
dim ([int](https://docs.python.org/3/library/functions.html#int))
### `dim` {#max.graph.dim.StaticDim.dim}
> dim: [int](https://docs.python.org/3/library/functions.html#int)
The size of the static dimension.
### `from_mlir()` {#max.graph.dim.StaticDim.from_mlir}
> static from\_mlir(attr)
Constructs a dimension from an `mlir.Attribute`.
**Parameters:**
attr (TypedAttr) – The MLIR attribute to parse into a dimension.
**Returns:**
The dimension represented by the MLIR Attr value.
**Return type:**
[Dim](#max.graph.dim.Dim)
### `parameters` {#max.graph.dim.StaticDim.parameters}
> property parameters: [Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[SymbolicDim](#max.graph.dim.SymbolicDim)]
Lists the symbolic dimension names on which this dim depends.
### `to_mlir()` {#max.graph.dim.StaticDim.to_mlir}
> to\_mlir()
Creates an `mlir.Attribute` representing this dimension.
This is used internally when constructing tensor MLIR types.
**Returns:**
An `mlir.Attribute` in the context representing the dimension.
**Return type:**
IntegerAttr
## `SymbolicDim` {#max.graph.dim.SymbolicDim}
> class max.graph.dim.SymbolicDim(value)
A symbolic tensor dimension.
Symbolic dimensions represent named dimensions in MO tensor types.
Symbolic dimensions don’t have a static value, but they give a dimension a
readable name that makes the model IR easier to understand, and they let
users hint to the compiler that two dimensions will have the same value,
which often enables important speedups.
In tensor type notation:
```default
!mo.tensor<[batch, x, 10], si32>
```
The first and second dimensions are named `batch` and `x` respectively.
Creating a `SymbolicDim`:
```python
from max.graph import SymbolicDim
dim = SymbolicDim("name")
```
Using `SymbolicDim` in a `TensorType`:
```python
from max.graph import TensorType, SymbolicDim, DeviceRef
from max.dtype import DType
tensor_type = TensorType(DType.bool, (SymbolicDim("batch"), SymbolicDim("x"), 10), device=DeviceRef.CPU())
```
**Parameters:**
name ([str](https://docs.python.org/3/library/stdtypes.html#str))
### `from_mlir()` {#max.graph.dim.SymbolicDim.from_mlir}
> static from\_mlir(attr)
Constructs a dimension from an `mlir.Attribute`.
**Parameters:**
attr (TypedAttr) – The MLIR attribute to parse into a dimension.
**Returns:**
The dimension represented by the MLIR Attr value.
**Return type:**
[Dim](#max.graph.dim.Dim)
### `name` {#max.graph.dim.SymbolicDim.name}
> name: [str](https://docs.python.org/3/library/stdtypes.html#str)
The name of the dimension.
### `parameters` {#max.graph.dim.SymbolicDim.parameters}
> property parameters: [Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[SymbolicDim](#max.graph.dim.SymbolicDim)]
Lists the symbolic dimension names on which this dim depends.
### `to_mlir()` {#max.graph.dim.SymbolicDim.to_mlir}
> to\_mlir()
Creates an `mlir.Attribute` representing this dimension.
This is used internally when constructing tensor MLIR types.
**Returns:**
An `mlir.Attribute` in the context representing the dimension.
**Return type:**
ParamDeclRefAttr
---
## graph (Graph)
APIs to build inference graphs for MAX Engine with Python.
## Classes
* [`BufferValue`](/max/api/python/graph/BufferValue): Represents a mutable semantic tensor within a Graph.
* [`Graph`](/max/api/python/graph/Graph): Represents a graph for MAX Engine.
* [`KernelLibrary`](/max/api/python/graph/KernelLibrary): Represents a library with custom ops.
* [`TensorValue`](/max/api/python/graph/TensorValue): Represents a value semantic tensor within a Graph.
* [`Value`](/max/api/python/graph/Value): Represents a symbolic value within a Graph.
* [`Weight`](/max/api/python/graph/Weight): Represents a weight value in a graph.
## Modules
* [`dim`](/max/api/python/graph/dim): APIs for graph value tensor dimensions.
* [`ops`](/max/api/python/graph/ops): Ops you can add when staging a graph.
* [`quantization`](/max/api/python/graph/quantization): APIs to quantize graph tensors.
* [`shape`](/max/api/python/graph/shape): APIs for graph value tensor shapes.
* [`type`](/max/api/python/graph/type): APIs for graph value types.
* [`weights`](/max/api/python/graph/weights): APIs for loading weights into a graph.
---
## ops
Implements operations used when staging a graph.
This module provides operations for building computational graphs in MAX. These
operations create, transform, and manipulate tensor values within the graph.
You can also use functions in [Graph](/max/api/python/graph/Graph) to add
constant values to your graph with operations like
[constant()](/max/api/python/graph/ops#max.graph.ops.constant).
The [TensorValue](/max/api/python/graph/TensorValue/) type (returned by most
operations) implements various dunder methods to support operations between
TensorValues, such as + for addition, \* for multiplication, and @ for
matrix multiplication. It also provides convenience methods like
[reshape()](/max/api/python/graph/TensorValue/#max.graph.TensorValue.reshape)
and
[flatten()](/max/api/python/graph/TensorValue/#max.graph.TensorValue.flatten).
### `Callable` {#max.graph.ops.Callable}
> class max.graph.ops.Callable
### `DeviceRef` {#max.graph.ops.DeviceRef}
> class max.graph.ops.DeviceRef(device\_type, id=0)
A symbolic device representation.
A DeviceRef consists of a DeviceKind and an id, and directly represents
the device attribute in MLIR.
The following example demonstrates how to create and use device references:
```python
from max.graph import DeviceRef
gpu_device = DeviceRef.GPU()
print(gpu_device) # Outputs: gpu:0
# Create a CPU device with a specific id
cpu_device = DeviceRef.CPU(id=1)
print(cpu_device) # Outputs: cpu:1
```
**Parameters:**
* device\_type ([DeviceKind](type.md#max.graph.type.DeviceKind))
* id ([int](https://docs.python.org/3/library/functions.html#int))
#### `CPU()` {#max.graph.ops.DeviceRef.CPU}
> static CPU(id=0)
Static method for creating a CPU device.
**Parameters:**
id ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[DeviceRef](type.md#max.graph.type.DeviceRef)
#### `GPU()` {#max.graph.ops.DeviceRef.GPU}
> static GPU(id=0)
Static method for creating a GPU device.
**Parameters:**
id ([int](https://docs.python.org/3/library/functions.html#int))
#### `from_mlir()` {#max.graph.ops.DeviceRef.from_mlir}
> static from\_mlir(attr)
Returns a device from an MLIR attribute.
**Parameters:**
attr (DeviceRefAttr)
**Return type:**
[DeviceRef](type.md#max.graph.type.DeviceRef)
#### `id` {#max.graph.ops.DeviceRef.id}
> id: [int](https://docs.python.org/3/library/functions.html#int)
#### `is_cpu()` {#max.graph.ops.DeviceRef.is_cpu}
> is\_cpu()
Returns true if the device is a CPU device.
### `InterpolationMode` {#max.graph.ops.InterpolationMode}
> class max.graph.ops.InterpolationMode(value, names=None, \*values, module=None, qualname=None, type=None, start=1, boundary=None)
Interpolation modes for image resize operations.
This enum defines the available interpolation methods that can be used
when resizing tensors. Currently only BICUBIC is implemented, with
BILINEAR and NEAREST planned for future support.
#### `BICUBIC` {#max.graph.ops.InterpolationMode.BICUBIC}
> BICUBIC = 'bicubic'
#### `BILINEAR` {#max.graph.ops.InterpolationMode.BILINEAR}
> BILINEAR = 'bilinear'
#### `NEAREST` {#max.graph.ops.InterpolationMode.NEAREST}
> NEAREST = 'nearest'
### `TensorType` {#max.graph.ops.TensorType}
> class max.graph.ops.TensorType(dtype, shape, device, \_layout=None)
A symbolic [`TensorType`](#max.graph.ops.TensorType).
This is not an eager tensor type! This contains no actual data, but
instead represents the type of a value at some point in time during model
execution.
Most internal values in a model will be tensors. This type represents
their element type (`dtype`) and dimensions (`dims`) at a specific point during
model computation. It allows us to do some optimistic optimizations and
shape inference during graph construction, and to provide more detailed
shape information to the compiler for further optimization passes.
The following example shows how to create a tensor type with static dimensions and access its properties:
```python
from max.graph import TensorType, DeviceRef
from max.dtype import DType
# Create a tensor type with float32 elements and static dimensions 2x3
tensor_type = TensorType(DType.float32, (2, 3), device=DeviceRef.CPU())
print(tensor_type.dtype) # Outputs: DType.float32
print(tensor_type.shape) # Outputs: [2, 3]
```
It can also represent a fully dynamic rank tensor. The presence of dynamic
rank tensors in a graph will often degrade performance dramatically and
prevent many classes of optimizations.
An optional device (`device`) can also be provided to indicate the explicit
device the tensor is associated with.
### `acos()` {#max.graph.ops.acos}
> max.graph.ops.acos(x)
Computes the arccosine (inverse cosine) of the input tensor.
Returns values in the range \[0, π] for inputs in \[-1, 1].
Creates a new op node to compute the elementwise arccosine of a
symbolic tensor and adds it to the graph, returning the symbolic result.
```python
from max.dtype import DType
from max.graph import Graph, TensorType, DeviceRef, ops

def acos_graph():
input_type = TensorType(dtype=DType.float32, shape=(3,), device=DeviceRef.CPU())
with Graph("acos_graph", input_types=(input_type,)) as graph:
x = graph.inputs[0]
out = ops.acos(x)
graph.output(out)
```
**Parameters:**
x ([TensorValue](TensorValue.md#max.graph.TensorValue)) – Input tensor with values in \[-1, 1]. If values are outside this
domain, they will be clamped to the valid range.
**Returns:**
Arccosine of the input in radians \[0, π]. The result will have
the same dtype and the same shape as the input.
**Raises:**
* Error – If the symbol doesn’t represent a tensor value.
* Error – If the input is not a floating-point dtype.
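The documented clamping of out-of-domain inputs can be illustrated with plain NumPy (a sketch of the semantics only, not the graph op itself; `acos_clamped` is a hypothetical helper):

```python
import numpy as np

def acos_clamped(x):
    # Clamp to the valid domain [-1, 1] before taking the arccosine,
    # mirroring the documented handling of out-of-domain values.
    return np.arccos(np.clip(x, -1.0, 1.0))

x = np.array([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=np.float32)
y = acos_clamped(x)
# acos(-2) is clamped to acos(-1) = pi; acos(2) is clamped to acos(1) = 0
```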
### `allgather()` {#max.graph.ops.allgather}
> max.graph.ops.allgather(inputs, signal\_buffers, axis=0)
Collective allgather operation.
This is a collective op: it takes input tensors from different devices and
produces output tensors on those same devices.
In particular, this operation gathers the inputs across the different
devices and concatenates them along the specified dimension.
The result is then broadcast back to the same devices that the inputs
came from.
**Parameters:**
* inputs ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)]) – The input tensors to gather.
* signal\_buffers ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[BufferValue](BufferValue.md#max.graph.BufferValue) | HasBufferValue]) – Device buffer values used for synchronization.
* axis ([int](https://docs.python.org/3/library/functions.html#int)) – Dimension to concatenate the input tensors. Defaults to 0.
**Returns:**
An iterable of outputs, one per input device, each holding the gathered result. Each output
tensor contains the concatenation of all inputs along the specified dimension.
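Setting aside devices and synchronization, the data movement amounts to concatenating all shards and handing each participant a copy of the result. A NumPy sketch (the `allgather_sketch` helper is hypothetical, not part of the API):

```python
import numpy as np

def allgather_sketch(inputs, axis=0):
    # Concatenate every per-device shard along `axis`, then return one
    # copy of the gathered result per participating input.
    gathered = np.concatenate(list(inputs), axis=axis)
    return [gathered.copy() for _ in inputs]

shards = [np.array([[1, 2]]), np.array([[3, 4]]), np.array([[5, 6]])]
outs = allgather_sketch(shards, axis=0)
# Each of the three outputs holds the full (3, 2) concatenation
```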
### `argmax()` {#max.graph.ops.argmax}
> max.graph.ops.argmax(x, axis=-1)
Reduces a symbolic tensor using an argmax operation.
For a tensor whose elements are all identical, this returns the index of
the first element on CPU, and an arbitrary index on GPU.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input tensor for the operation.
* axis ([int](https://docs.python.org/3/library/functions.html#int)) – The axis along which to compute the reduction. If negative,
indexes from the last dimension. For example, a value of -1 will
compute the reduction along the last dimension.
**Returns:**
A symbolic tensor representing the result of the argmax operation.
The tensor will have the same rank as the input tensor, and the same
shape except along the `axis` dimension which will have size 1.
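The shape contract above (rank preserved, reduced axis of size 1) matches NumPy's argmax with the reduced axis explicitly kept:

```python
import numpy as np

x = np.array([[1, 5, 3],
              [9, 2, 9]])
# Keeping the reduced axis as size 1 preserves the input rank,
# matching the shape contract described above.
idx = np.expand_dims(np.argmax(x, axis=-1), axis=-1)
```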
### `argmin()` {#max.graph.ops.argmin}
> max.graph.ops.argmin(x, axis=-1)
Reduces a symbolic tensor using an argmin operation.
For a tensor whose elements are all identical, this returns the index of
the first element on CPU, and an arbitrary index on GPU.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input tensor for the operation.
* axis ([int](https://docs.python.org/3/library/functions.html#int)) – The axis along which to compute the reduction. If negative,
indexes from the last dimension. For example, a value of -1 will
compute the reduction along the last dimension.
**Returns:**
A symbolic tensor representing the result of the argmin operation.
The tensor will have the same rank as the input tensor, and the same
shape except along the `axis` dimension which will have size 1.
### `argsort()` {#max.graph.ops.argsort}
> max.graph.ops.argsort(x, ascending=True)
Returns the indices that would sort a tensor.
This function returns the indices that would sort the input tensor along
its first dimension. The returned indices are of type int64.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue) – Input tensor to be sorted.
* ascending ([bool](https://docs.python.org/3/library/functions.html#bool)) – If True (default), sort in ascending order. If False, sort in
descending order.
**Returns:**
A tensor of indices of the same shape as the input tensor.
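The same index semantics can be demonstrated with NumPy. Descending order is obtained here by negating the input, one of several equivalent tricks:

```python
import numpy as np

x = np.array([3.0, 1.0, 2.0])
asc = np.argsort(x)    # indices that would sort x ascending
desc = np.argsort(-x)  # indices that would sort x descending
```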
### `as_interleaved_complex()` {#max.graph.ops.as_interleaved_complex}
> max.graph.ops.as\_interleaved\_complex(x)
Reshapes the input symbolic tensor as complex from alternating (real, imag).
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – A symbolic tensor representing complex numbers as
alternating pairs of (real, imag) real-valued numbers. Its last
dimension must have an even size.
**Returns:**
A symbolic tensor with the same values regrouped as complex numbers. The
result has the same dimensions as the input for all but the last
dimension, which is halved, followed by a final dimension of size 2
holding the (real, imag) components.
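The reshape this op performs can be sketched in NumPy (the `as_interleaved_complex_sketch` name is illustrative, not part of the API):

```python
import numpy as np

def as_interleaved_complex_sketch(x):
    # The last dimension holds alternating (real, imag) pairs: halve it
    # and add a trailing dimension of size 2 carrying each pair.
    assert x.shape[-1] % 2 == 0, "last dimension must have an even size"
    return x.reshape(*x.shape[:-1], x.shape[-1] // 2, 2)

x = np.arange(12.0).reshape(3, 4)  # last dim: r0, i0, r1, i1
out = as_interleaved_complex_sketch(x)
```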
### `avg_pool2d()` {#max.graph.ops.avg_pool2d}
> max.graph.ops.avg\_pool2d(input, kernel\_size, stride=1, dilation=1, padding=0, ceil\_mode=False, count\_boundary=True)
Performs a 2D average pooling operation on the input tensor.
This function applies a 2D average pooling operation to the input tensor \[N, H, W, C].
The pooling operation slides a window of size kernel\_size over the input
tensor, and computes the average value within each window.
**Parameters:**
* input (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input tensor to perform the pooling operation on.
* kernel\_size ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)], [int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The size of the sliding blocks.
* stride ([int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – The stride of the sliding blocks in the input dimension.
* dilation ([int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – The spacing between the kernel elements.
* padding ([int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – Zero-padding added on both sides of the input.
* ceil\_mode ([bool](https://docs.python.org/3/library/functions.html#bool)) – If true, use ceil instead of floor to compute the output shape.
* count\_boundary ([bool](https://docs.python.org/3/library/functions.html#bool)) – If true, count the padding elements when computing the average.
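For the simplified case of non-overlapping windows (stride equal to kernel size, no padding), the computation reduces to a reshape and a mean. A NumPy sketch on the \[N, H, W, C] layout, with a hypothetical `avg_pool2d_sketch` helper:

```python
import numpy as np

def avg_pool2d_sketch(x, k):
    # Non-overlapping k x k average pooling on an [N, H, W, C] tensor.
    # Assumes stride == kernel size and no padding (a simplification).
    n, h, w, c = x.shape
    return x.reshape(n, h // k, k, w // k, k, c).mean(axis=(2, 4))

x = np.arange(16.0).reshape(1, 4, 4, 1)
out = avg_pool2d_sketch(x, 2)
# Top-left 2x2 block is [0, 1, 4, 5], whose average is 2.5
```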
### `band_part()` {#max.graph.ops.band_part}
> max.graph.ops.band\_part(x, num\_lower=None, num\_upper=None, exclude=False)
Masks out everything except a diagonal band of an input matrix.
Copies a tensor setting everything outside the central diagonal band of the
matrices to zero, where all but the last two axes are effectively batches,
and the last two axes define sub matrices.
Assumes the input has dimensions \[I, J, …, M, N], then the output tensor
has the same shape as the input, and the values are given by
```python
out[i, j, ..., m, n] = in_band(m, n) * input[i, j, ..., m, n]
```
with the indicator function:
```python
in_band(m, n) = (num_lower is None || (m - n) <= num_lower) &&
                (num_upper is None || (n - m) <= num_upper)
```
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input to mask out.
* num\_lower ([int](https://docs.python.org/3/library/functions.html#int) | None) – The number of diagonal bands to include below the central
diagonal. If None, include the entire lower triangle.
* num\_upper ([int](https://docs.python.org/3/library/functions.html#int) | None) – The number of diagonal bands to include above the central
diagonal. If None, include the entire upper triangle.
* exclude ([bool](https://docs.python.org/3/library/functions.html#bool)) – If true, invert the selection of elements to mask. Elements
in the band are set to zero.
**Returns:**
A symbolic tensor value with the configured selection masked out
to 0 values, and the remaining values copied from the input tensor.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If the input tensor rank is less than 2, or if num\_lower/num\_upper
are out of bounds for statically known dimensions.
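The indicator function above can be implemented directly. This NumPy sketch (the `band_part_sketch` helper is hypothetical) reproduces the documented behavior on the last two axes of a single matrix:

```python
import numpy as np

def band_part_sketch(x, num_lower=None, num_upper=None, exclude=False):
    m = np.arange(x.shape[-2])[:, None]  # row indices
    n = np.arange(x.shape[-1])[None, :]  # column indices
    # in_band(m, n) from the formula above; None means "no limit".
    lower_ok = np.ones((x.shape[-2], x.shape[-1]), bool) if num_lower is None else (m - n) <= num_lower
    upper_ok = np.ones((x.shape[-2], x.shape[-1]), bool) if num_upper is None else (n - m) <= num_upper
    in_band = lower_ok & upper_ok
    if exclude:
        in_band = ~in_band
    return np.where(in_band, x, 0)

x = np.arange(16).reshape(4, 4)
# band_part_sketch(x, 0, None) keeps the upper triangle, like np.triu
```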
### `broadcast_to()` {#max.graph.ops.broadcast_to}
> max.graph.ops.broadcast\_to(x, shape, out\_dims=None)
Broadcasts a symbolic tensor.
Broadcasts the input tensor to the specified shape.
Dimensions in the input must be one or match the target dimension.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue) – The input symbolic tensor to broadcast.
This tensor may not contain any dynamic dimensions.
* shape ([TensorValue](TensorValue.md#max.graph.TensorValue) | [Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The new shape as a list of dimensions.
Dynamic dimensions are not allowed.
* out\_dims ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]] | None) – Output dims used only for tensor-valued shape.
**Returns:**
A symbolic tensor with the same elements as the original tensor, but
in a new shape. Its symbolic shape is the same as `shape`.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – if a tensor-valued shape is passed without out\_dims.
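The broadcast rule above (size-1 dimensions stretch to the target size, all others must match) is the same as NumPy's:

```python
import numpy as np

x = np.array([[1], [2]])  # shape (2, 1)
# The size-1 second dimension expands to 3; the first must match exactly.
y = np.broadcast_to(x, (2, 3))
```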
### `buffer_load()` {#max.graph.ops.buffer_load}
> max.graph.ops.buffer\_load(x)
Loads the input buffer into a tensor.
It copies the in-place mutable buffer into an immutable tensor graph value.
This is semantically equivalent to a copy from the mutable buffer x to the
immutable, value-semantic tensor output.
**Parameters:**
x ([BufferValue](BufferValue.md#max.graph.BufferValue)) – The buffer to be loaded to a tensor.
**Returns:**
A tensor graph value representing a copy of the buffer loaded.
### `buffer_store()` {#max.graph.ops.buffer_store}
> max.graph.ops.buffer\_store(destination, source)
Stores the input tensor into the in-out buffer.
It stores the immutable input tensor source in the mutable buffer destination.
This is semantically equivalent to a copy from the source tensor to the destination buffer.
**Parameters:**
* destination ([BufferValue](BufferValue.md#max.graph.BufferValue) | HasBufferValue) – The buffer to store the tensor in.
* source (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The tensor to be stored in the buffer.
**Return type:**
None
### `buffer_store_slice()` {#max.graph.ops.buffer_store_slice}
> max.graph.ops.buffer\_store\_slice(destination, source, indices)
Stores the input tensor into a slice of the input buffer.
It stores the immutable input tensor source in the mutable tensor destination.
This is semantically equivalent to a copy from the source tensor to a slice of the
destination buffer at the index specified by indices.
**Parameters:**
* destination ([BufferValue](BufferValue.md#max.graph.BufferValue) | HasBufferValue) – The buffer to store the tensor in.
* source (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The tensor to be stored in the buffer.
* indices ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[TensorValue](TensorValue.md#max.graph.TensorValue) | [int](https://docs.python.org/3/library/functions.html#int) | [slice](https://docs.python.org/3/library/functions.html#slice) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[slice](https://docs.python.org/3/library/functions.html#slice), [int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]] | builtins.ellipsis]) – The index in the buffer where the tensor should be stored
**Return type:**
None
### `call()` {#max.graph.ops.call}
> max.graph.ops.call(graph, \*args, prefix='')
Call a graph with the provided arguments and return its results.
This function invokes a previously defined graph, passing in the provided
arguments and the current chain value, and returns the results.
The body of the graph is ultimately inlined into the caller, so the chain
value is only used for serialization if the subgraph’s body contains an
operation that makes use of it in the first place.
The current advantage of using subgraphs is that it offers a way to improve
compile times for operations that are used repeatedly in a model. As a
secondary benefit, it also makes the IR more readable by allowing control
flow to be expressed in a more natural way.
**Parameters:**
* graph ([Graph](Graph.md#max.graph.Graph)) – The graph to call
* \*args ([Value](Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]) – Arguments to pass to the called graph
* prefix ([str](https://docs.python.org/3/library/stdtypes.html#str)) – Prefix to add to the names of any weights in the subgraph
**Returns:**
Either a single Value or a list of Values representing the graph outputs
(excluding the chain value which is handled internally)
### `cast()` {#max.graph.ops.cast}
> max.graph.ops.cast(x, dtype)
Casts a symbolic tensor to a different data type.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue) – The input tensor to cast.
* dtype ([DType](../dtype.md#max.dtype.DType)) – The target dtype to which the tensor is cast.
**Returns:**
A new symbolic tensor with the same shape as the input and the
specified dtype.
### `chunk()` {#max.graph.ops.chunk}
> max.graph.ops.chunk(x, chunks, axis=0)
Chunk the tensor into an exact number of chunks along the specified dim.
**Example:**
```pycon
>>> a = TensorValue([1, 2, 3, 4])
>>> chunk(a, 2, 0)
[TensorValue([1, 2]), TensorValue([3, 4])]
```
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The tensor to chunk.
* chunks ([int](https://docs.python.org/3/library/functions.html#int)) – The number of chunks to split the tensor into.
chunks must statically evenly divide x.shape\[axis].
* axis ([int](https://docs.python.org/3/library/functions.html#int)) – The axis to split the tensor along.
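The even-division requirement mirrors NumPy's `np.split`, which can be used to sketch the behavior:

```python
import numpy as np

x = np.array([1, 2, 3, 4])
# np.split raises if the axis size is not evenly divisible, mirroring
# the "chunks must statically evenly divide x.shape[axis]" constraint.
parts = np.split(x, 2, axis=0)
```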
### `concat()` {#max.graph.ops.concat}
> max.graph.ops.concat(original\_vals, axis=0)
Concatenates a list of symbolic tensors along an axis.
Joins multiple tensors along a specified dimension. This operation requires
the functional API since it operates on multiple tensors. All input tensors
must have the same rank and the same size in all dimensions except the
concatenation axis.
```python
import max.functional as F
from max.tensor import Tensor
# Create two 2x2 matrices
a = Tensor.constant([[1, 2], [3, 4]])
b = Tensor.constant([[5, 6], [7, 8]])
# Concatenate along axis 0 (rows) - stacks vertically
vertical = F.concat([a, b], axis=0)
print(f"Concatenated along axis 0: {vertical.shape}")
# Output: Concatenated along axis 0: [Dim(4), Dim(2)]
print(vertical)
# [[1, 2],
#  [3, 4],
#  [5, 6],
#  [7, 8]]
# Concatenate along axis 1 (columns) - joins horizontally
horizontal = F.concat([a, b], axis=1)
print(f"Concatenated along axis 1: {horizontal.shape}")
# Output: Concatenated along axis 1: [Dim(2), Dim(4)]
print(horizontal)
# [[1, 2, 5, 6],
#  [3, 4, 7, 8]]
```
**Parameters:**
* original\_vals ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)]) – The list of symbolic tensor values to concatenate. Each tensor must have the same
dtype and rank, and must have the same dimension size for each
dimension other than `axis`.
* axis ([int](https://docs.python.org/3/library/functions.html#int)) – The axis to concatenate along. If negative, indexes relative
to the end of the tensor shape. For instance, `concat(vs, -1)`
will concatenate along the last dimension.
**Returns:**
A new symbolic tensor representing the concatenation result. It will
have the same rank as each input tensor, and its dimensions will be the same
as each input tensor’s for each dimension other than axis, which will
have size equal to the sum of all tensor’s size for that dimension.
### `cond()` {#max.graph.ops.cond}
> max.graph.ops.cond(pred, out\_types, then\_fn, else\_fn)
Conditionally execute one of two branches based on a boolean predicate.
Both branches must return the same number and types of values as specified
in `out_types`. Buffer mutations in branches are tracked automatically
through the chain mechanism.
Examples:
1. Basic conditional with return values:
> ```python
> def then_fn():
> return ops.constant(1, DType.int32, device=DeviceRef.CPU())
> def else_fn():
> return ops.constant(0, DType.int32, device=DeviceRef.CPU())
>
> result = ops.cond(
> pred,
> [TensorType(DType.int32, [], device=device)],
> then_fn,
> else_fn
> )
> ```
2. Conditional with buffer mutations:
> ```python
> def then_fn():
> ops.inplace_custom("increment", device=buffer.device, values=[buffer])
> def else_fn():
> ops.inplace_custom("decrement", device=buffer.device, values=[buffer])
>
> ops.cond(pred, None, then_fn, else_fn)
> ```
**Parameters:**
* pred – Boolean scalar tensor of type `DType.bool` determining branch execution.
* out\_types – Expected output types for both branches. Use [`None`](https://docs.python.org/3/library/constants.html#None) for branches that don’t return values.
* then\_fn – Callable executed when `pred` is True. Must return values matching `out_types` if `out_types` is not [`None`](https://docs.python.org/3/library/constants.html#None).
* else\_fn – Callable executed when `pred` is False. Must return values matching `out_types` if `out_types` is not [`None`](https://docs.python.org/3/library/constants.html#None).
**Returns:**
List of output values from executed branch. Returns empty list when `out_types`
is [`None`](https://docs.python.org/3/library/constants.html#None)
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If branches return different numbers of results or result types
don’t match `out_types`
:::note Note
Buffer operations in branches automatically update the global chain state to
maintain mutation ordering constraints
:::
### `constant()` {#max.graph.ops.constant}
> max.graph.ops.constant(value, dtype=None, device=None)
Adds a node representing a constant operation.
The value of this constant will have the type TensorType with the
same shape as value. If value is a scalar type, it will create a TensorType with 0 dimensions.
The constant will be loaded with the specified dtype.
If the constant does not fit within the specified dtype, an error is raised.
Warning: Loading the constant could result in precision loss.
For example, loading 16777217 as a float32 will result in 16777216.0.
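The precision-loss example in the warning can be checked directly: 2\*\*24 + 1 is the smallest positive integer that float32 cannot represent exactly.

```python
import numpy as np

# float32 has a 24-bit significand, so 16777217 = 2**24 + 1 rounds
# to the nearest representable value, 16777216.0.
x = np.float32(16777217)
```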
**Parameters:**
* value ([DLPackArray](../driver.md#max.driver.DLPackArray) | [Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[float](https://docs.python.org/3/library/functions.html#float) | [number](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.number)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[Number | NestedArray]] | [float](https://docs.python.org/3/library/functions.html#float) | [number](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.number)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]) – The constant’s value.
* dtype ([DType](../dtype.md#max.dtype.DType) | None) – The constant tensor’s element type.
* device ([Device](../driver.md#max.driver.Device) | [DeviceRef](type.md#max.graph.type.DeviceRef) | None) – The device the constant lives on.
**Returns:**
A graph value containing the constant data as an attribute.
### `constant_external()` {#max.graph.ops.constant_external}
> max.graph.ops.constant\_external(name, type)
Registers an external constant (weight) in the graph of a given type.
Two external constants with the same name and type refer to the same weight.
Two external constants with the same name and different types are
incompatible and will fail compilation.
**Parameters:**
* name ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The name of the external constant.
This should be the fully-qualified weight name and must be unique.
* type ([TensorType](type.md#max.graph.type.TensorType)) – The type of the constant value.
**Returns:**
A tensor value of the specified type, representing the weight value
associated with the name at compile time.
### `conv2d()` {#max.graph.ops.conv2d}
> max.graph.ops.conv2d(x, filter, stride=(1, 1), dilation=(1, 1), padding=(0, 0, 0, 0), groups=1, bias=None, input\_layout=ConvInputLayout.NHWC, filter\_layout=FilterLayout.RSCF)
Computes the 2-D convolution product of the input with the given filter, bias,
strides, dilations, paddings, and groups.
The op supports 2-D convolution, with the following layout assumptions:
* input x has NHWC layout, i.e.,
(batch\_size, height, width, in\_channels)
* filter has layout RSCF, i.e.,
(height, width, in\_channels / num\_groups, out\_channels)
* bias has shape (out\_channels,)
The padding values are expected to take the form (pad\_dim1\_before,
pad\_dim1\_after, pad\_dim2\_before, pad\_dim2\_after…) and represent padding
0’s before and after the indicated spatial dimensions in input. In 2-D
convolution, dim1 here represents H and dim2 represents W. In Python like
syntax, padding a 2x3 spatial input with \[0, 1, 2, 1] would yield:
```python
input = [
[1, 2, 3],
[4, 5, 6]
]
# Shape is 2x3
padded_input = [
[0, 0, 1, 2, 3, 0],
[0, 0, 4, 5, 6, 0],
[0, 0, 0, 0, 0, 0]
]
# Shape is 3x6
```
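The `padded_input` above can be reproduced with NumPy's `np.pad`, grouping the (before, after) amounts per spatial dimension:

```python
import numpy as np

x = np.array([[1, 2, 3],
              [4, 5, 6]])
# padding (0, 1, 2, 1): 0 rows before / 1 row after (H),
# then 2 columns before / 1 column after (W).
padded = np.pad(x, ((0, 1), (2, 1)))
```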
This op currently only supports strides and padding on the input.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – An NHWC input tensor to perform the convolution upon.
* filter (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The convolution filter in RSCF layout:
(height, width, in\_channels / num\_groups, out\_channels).
* stride ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – The stride of the convolution operation.
* dilation ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – The spacing between the kernel points.
* padding ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – The amount of padding applied to the input.
* groups ([int](https://docs.python.org/3/library/functions.html#int)) – When greater than 1, divides the convolution into multiple
parallel convolutions. The number of input and output
channels must both be divisible by the number of groups.
* bias (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray) | None)
* input\_layout ([ConvInputLayout](type.md#max.graph.type.ConvInputLayout))
* filter\_layout ([FilterLayout](type.md#max.graph.type.FilterLayout))
**Returns:**
A symbolic tensor value with the convolution applied.
### `conv2d_transpose()` {#max.graph.ops.conv2d_transpose}
> max.graph.ops.conv2d\_transpose(x, filter, stride=(1, 1), dilation=(1, 1), padding=(0, 0, 0, 0), output\_paddings=(0, 0), bias=None, input\_layout=ConvInputLayout.NHWC, filter\_layout=FilterLayout.RSCF)
Computes the 2-D deconvolution of the input with the given filter,
strides, dilations, paddings, and groups.
The op supports the transpose (gradient) of convolution, with the following layout assumptions
(note that out\_channels is defined with respect to the original convolution):
* input x has NHWC layout, i.e.,
(batch\_size, height, width, in\_channels)
* filter has layout RSCF, i.e.,
(kernel\_height, kernel\_width, out\_channels, in\_channels)
* bias has shape (out\_channels,)
This op effectively computes the gradient of a convolution with
respect to its input (as if the original convolution operation had the same
filter and hyperparameters as this op).
The padding values are expected to take the form (pad\_dim1\_before,
pad\_dim1\_after, pad\_dim2\_before, pad\_dim2\_after…) and represent padding
0’s before and after the indicated spatial dimensions in input. In 2D
ConvTranspose, dim1 here represents H\_out and dim2 represents W\_out. In
Python-like syntax, padding a 2x4 spatial output with \[0, 1, 2, 1] would
yield:
```python
output = [
    [1, 2, 3, 4],
    [5, 6, 7, 8]
]
# Shape is 2x4
padded_output = [
    [3],
]
# Shape is 1x1
```
**Parameters:**
* x – An NHWC input tensor to perform the convolution upon.
* filter (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The convolution filter in RSCF layout:
(height, width, out\_channels, in\_channels).
* stride ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – The stride of the sliding window for each spatial dimension of the input.
If a single value is given, it is replicated for both the H and W dimensions.
The N and C dimensions are never strided.
* dilation ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – The spacing between the kernel points.
* padding ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – The amount of padding applied to the input.
* output\_paddings ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – Resolves the ambiguity between the multiple possible
output shapes when any stride is greater than 1, by appending
output\_paddings\[i] zeros at the end of the output’s i-th spatial
axis. Only output\_paddings = (0, 0) is currently supported.
* bias (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray) | None) – tensor of shape (out\_channels,)
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray))
* input\_layout ([ConvInputLayout](type.md#max.graph.type.ConvInputLayout))
* filter\_layout ([FilterLayout](type.md#max.graph.type.FilterLayout))
**Returns:**
A symbolic tensor value with the convolution applied.
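As a quick sanity check on shapes, the standard transposed-convolution output-size formula (the same one used by frameworks such as PyTorch) can be sketched per spatial dimension; `conv_transpose_out_dim` below is a hypothetical helper, not part of `max.graph`:

```python
# Hypothetical helper illustrating the standard transposed-convolution
# output-size formula per spatial dimension; not part of max.graph.
def conv_transpose_out_dim(in_dim, kernel, stride=1, dilation=1,
                           pad_before=0, pad_after=0, output_padding=0):
    # Each input element is spaced out by `stride`, the (dilated) kernel
    # extends the result, and padding crops it back down.
    return ((in_dim - 1) * stride - pad_before - pad_after
            + dilation * (kernel - 1) + output_padding + 1)

# A 4-wide input upsampled with a 3-wide kernel and stride 2 gives width 9.
print(conv_transpose_out_dim(4, kernel=3, stride=2))  # 9
```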
### `conv3d()` {#max.graph.ops.conv3d}
> max.graph.ops.conv3d(x, filter, stride=(1, 1, 1), dilation=(1, 1, 1), padding=(0, 0, 0, 0, 0, 0), groups=1, bias=None, input\_layout=ConvInputLayout.NHWC, filter\_layout=FilterLayout.QRSCF)
Computes the 3-D convolution product of the input with the given filter,
strides, dilations, paddings, and groups.
The op supports 3-D convolution, with the following layout assumptions:
* input has NDHWC layout, i.e.,
(batch\_size, depth, height, width, in\_channels)
* filter has layout QRSCF, i.e.,
(depth, height, width, in\_channels / num\_groups, out\_channels)
The padding values are expected to take the form (pad\_dim1\_before,
pad\_dim1\_after, pad\_dim2\_before, pad\_dim2\_after…) and represent padding
0’s before and after the indicated spatial dimensions in input. In 3-D
convolution, dim1 here represents D, dim2 represents H, and dim3 represents W. In Python-like
syntax, padding a 2x3 spatial input with \[0, 1, 2, 1] (showing only two spatial dimensions for brevity) would yield:
```python
input = [
    [1, 2, 3],
    [4, 5, 6]
]
# Shape is 2x3
padded_input = [
    [0, 0, 1, 2, 3, 0],
    [0, 0, 4, 5, 6, 0],
    [0, 0, 0, 0, 0, 0]
]
# Shape is 3x6
```
This op currently only supports strides and padding on the input.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – An NDHWC input tensor to perform the convolution upon.
* filter (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The convolution filter in QRSCF layout:
(depth, height, width, in\_channels / num\_groups, out\_channels).
* stride ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – The stride of the convolution operation.
* dilation ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – The spacing between the kernel points.
* padding ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – The amount of padding applied to the input.
* groups ([int](https://docs.python.org/3/library/functions.html#int)) – When greater than 1, divides the convolution into multiple
parallel convolutions. The number of input and output
channels must both be divisible by the number of groups.
* bias (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray) | None)
* input\_layout ([ConvInputLayout](type.md#max.graph.type.ConvInputLayout))
* filter\_layout ([FilterLayout](type.md#max.graph.type.FilterLayout))
**Returns:**
A symbolic tensor value with the convolution applied.
Output shape = (batch\_size, depth, height, width, out\_channels).
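The spatial output dimensions follow the standard convolution output-size formula, applied independently per dimension; `conv_out_dim` below is a hypothetical helper, not part of `max.graph`:

```python
# Hypothetical helper for the standard convolution output-size formula
# (per spatial dimension); not part of max.graph.
def conv_out_dim(in_dim, kernel, stride=1, dilation=1, pad_before=0, pad_after=0):
    # Dilation spreads the kernel taps apart before the window slides.
    effective_kernel = dilation * (kernel - 1) + 1
    return (in_dim + pad_before + pad_after - effective_kernel) // stride + 1

# Depth, height, and width shrink independently: a 5x5x5 volume with a
# 3x3x3 kernel and no padding yields a 3x3x3 output.
print([conv_out_dim(5, 3) for _ in range(3)])  # [3, 3, 3]
```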
### `cumsum()` {#max.graph.ops.cumsum}
> max.graph.ops.cumsum(x, axis=-1, exclusive=False, reverse=False)
Computes the cumulative sum of the input tensor along the given axis.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input tensor to sum over.
* axis ([int](https://docs.python.org/3/library/functions.html#int)) – The axis along which to compute the sum. If negative,
indexes from the last dimension. For example, a value of -1 will
compute the sum along the last dimension.
* exclusive ([bool](https://docs.python.org/3/library/functions.html#bool)) – If set, start at 0 and exclude the final element.
Otherwise, start with the first element. Said another way, cumsum
computes \[sum(x\[…, :i + 1, …]) for i in range(x.shape\[axis])].
If exclusive is set, it instead computes \[sum(x\[…, :i, …]) for i in range(x.shape\[axis])].
* reverse ([bool](https://docs.python.org/3/library/functions.html#bool)) – If set, start from the end. In other words, the first element
will be the total sum, with each element following counting
downwards; or \[sum(x\[…, i:, …]) for i in range(x.shape\[axis])].
**Returns:**
A symbolic tensor representing the result of the cumsum operation.
The tensor will have the same type as the input tensor. The computed
values will be the cumulative sum of the values along the given axis,
according to the specified parameters:
* if exclusive is set, the first value will be 0, and the last
value will be excluded from the sum
* if reverse is set, the sum will be computed starting at the
back of the axis back to the front, rather than front-to-back
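The combined semantics of `exclusive` and `reverse` can be sketched in plain Python over a 1-D list; `cumsum_ref` is a hypothetical reference helper, not the MAX op:

```python
# Reference semantics of cumsum on a 1-D list (hypothetical helper,
# not the MAX op), mirroring the parameter definitions above.
def cumsum_ref(x, exclusive=False, reverse=False):
    if reverse:
        x = x[::-1]
    out, total = [], 0
    for v in x:
        if exclusive:
            out.append(total)  # running sum BEFORE adding this element
            total += v
        else:
            total += v         # running sum INCLUDING this element
            out.append(total)
    if reverse:
        out = out[::-1]
    return out

print(cumsum_ref([1, 2, 3]))                  # [1, 3, 6]
print(cumsum_ref([1, 2, 3], exclusive=True))  # [0, 1, 3]
print(cumsum_ref([1, 2, 3], reverse=True))    # [6, 5, 3]
```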
### `custom()` {#max.graph.ops.custom}
> max.graph.ops.custom(name, device, values, out\_types, parameters=None)
Creates a node to execute a custom graph operation in the graph.
The custom op should be registered by annotating a function with the
[@compiler.register](/mojo/manual/decorators/compiler-register/)
decorator.
**Parameters:**
* name ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The op name provided to `@compiler.register`.
* values ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[Value](Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The op function’s arguments.
* out\_types ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[Type](type.md#max.graph.type.Type)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The list of the op function’s return types.
* parameters ([Mapping](https://docs.python.org/3/library/collections.abc.html#collections.abc.Mapping)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [bool](https://docs.python.org/3/library/functions.html#bool) | [int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [DType](../dtype.md#max.dtype.DType)] | None) – Dictionary of extra parameters expected by the kernel.
* device ([Device](../driver.md#max.driver.Device) | [DeviceRef](type.md#max.graph.type.DeviceRef)) – Device that the op is assigned to.
This becomes a target parameter to the kernel.
**Returns:**
Symbolic values representing the outputs of the op in the graph.
These correspond 1:1 with the types passed as `out_types`.
### `dequantize()` {#max.graph.ops.dequantize}
> max.graph.ops.dequantize(encoding, quantized)
Dequantizes a quantized tensor to floating point.
NOTE: Currently this supports Q4\_0, Q4\_K, and Q6\_K encodings only.
**Parameters:**
* encoding ([QuantizationEncoding](quantization.md#max.graph.quantization.QuantizationEncoding)) – The quantization encoding to use.
* quantized ([TensorValue](TensorValue.md#max.graph.TensorValue)) – The quantized tensor to dequantize.
### `distributed_broadcast()` {#max.graph.ops.distributed_broadcast}
> max.graph.ops.distributed\_broadcast(input, signal\_buffers)
Broadcasts a tensor from the source GPU to all GPUs.
This op is a collective operation which broadcasts a tensor from the source
GPU (where the input tensor resides) to all participating GPUs. Each GPU
receives a copy of the input tensor.
**Parameters:**
* input (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – Input tensor to broadcast. The device where this tensor resides
becomes the root/source of the broadcast.
* signal\_buffers ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[BufferValue](BufferValue.md#max.graph.BufferValue) | HasBufferValue]) – Device buffer values used for synchronization.
The number of signal buffers determines the number of participating
GPUs.
**Returns:**
List of output tensors, one per device. Each output tensor has the
same shape and dtype as the input tensor.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If input tensor device is not found in signal buffer devices,
if devices are not unique, or if there are fewer than 2 signal buffers.
### `div()` {#max.graph.ops.div}
> max.graph.ops.div(lhs, rhs)
Divides two symbolic tensors using true division (Python operator /).
For integer operands, this performs true division by promoting to float,
matching Python’s / operator behavior. For floating-point operands,
this performs standard floating-point division.
Creates a new op node to compute the division of two symbol tensor values
and adds it to the graph, returning the symbolic result.
**Parameters:**
* lhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The symbol to use as left side of the division.
* rhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The symbol to use as right side of the division.
**Returns:**
A symbolic tensor value representing the output of the division. The
result will have:
* a floating-point dtype for integer operands, or the promoted dtype for mixed types
* the same shape as the broadcast of the two input shapes.
**Raises:**
* Error – If the input values’ shapes are not compatible for broadcasting.
* Error – If one of the input values has an unsupported dtype.
* Error – If the two symbols are parts of different graphs.
### `exp()` {#max.graph.ops.exp}
> max.graph.ops.exp(x)
Computes the elementwise exp (exponential) function of a symbolic tensor.
Creates a new op node to compute the elementwise exponential function of a
symbolic tensor and adds it to the graph, returning the symbolic result.
The exp function is fundamental in neural networks, used in attention
mechanisms, activation functions, and probability distributions.
```python
import max.functional as F
from max.tensor import Tensor

# Create input tensor
x = Tensor.constant([0.0, 1.0, 2.0])

# Compute exponential
result = F.exp(x)
print(result)
# Output: [1.0, 2.718..., 7.389...]
# (e^0 = 1, e^1 ≈ 2.718, e^2 ≈ 7.389)
```
`exp` is defined as `exp(x) = e^x`, where `e` is Euler’s number.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The symbolic tensor to use as the input to the exp function
computation.
**Returns:**
A new symbolic tensor value representing the output of the exp
value computation.
**Raises:**
Error – If the symbol doesn’t represent a tensor value.
### `flatten()` {#max.graph.ops.flatten}
> max.graph.ops.flatten(x, start\_dim=0, end\_dim=-1)
Flattens the specified dims of a symbolic tensor.
The number and order of the elements in the tensor are unchanged.
All dimensions from start\_dim to end\_dim (inclusive) are merged into a single output dim.
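The resulting shape can be sketched with a small hypothetical helper (`flatten_shape` below is for illustration only, not part of `max.graph`):

```python
# Hypothetical helper computing the shape produced by flatten(x,
# start_dim, end_dim); not part of max.graph.
def flatten_shape(shape, start_dim=0, end_dim=-1):
    n = len(shape)
    start = start_dim % n   # normalize negative dims
    end = end_dim % n
    merged = 1
    for d in shape[start:end + 1]:
        merged *= d         # dims start..end (inclusive) collapse into one
    return shape[:start] + (merged,) + shape[end + 1:]

print(flatten_shape((2, 3, 4)))          # (24,)  -- flatten everything
print(flatten_shape((2, 3, 4), 1, -1))   # (2, 12) -- keep the batch dim
```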
### `fold()` {#max.graph.ops.fold}
> max.graph.ops.fold(input, output\_size, kernel\_size, stride=1, dilation=1, padding=0)
Combines an array of sliding blocks into a larger containing tensor.
The input tensor must have shape `(N, C * kernel_sizes, L)` where `N` is
the batch dimension, `C` is the number of channels, `kernel_sizes` is
the product of the kernel sizes, and `L` is the number of local blocks.
The resulting output tensor will have shape
`(N, C, output_shape[0], output_shape[1])`.
`L`, the number of blocks, must be equivalent to:
`prod((output_size[d] + 2 * padding[d] - dilation[d] * (kernel_size[d] - 1) - 1) / stride[d] + 1)`
where `d` is over all spatial dimensions.
**Parameters:**
* input (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The 3D tensor to fold with shape `(N, C * kernel sizes, L)`.
* output\_size ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)], [int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – Spatial dimensions of the output tensor. Must be a tuple of two ints.
* kernel\_size ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)], [int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The size of the sliding blocks. Must be a tuple of two ints.
* stride ([int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – The stride of the sliding blocks in the input dimension
(can be an int or a tuple of two ints).
* dilation ([int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – The spacing between the kernel elements.
(can be an int or a tuple of two ints).
* padding ([int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – 0-paddings to be added on both sides of the inputs.
(can be an int or a tuple of two ints).
**Returns:**
The folded 4D tensor with shape `(N, C, output_shape[0], output_shape[1])`.
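The constraint on `L` above can be checked with a small hypothetical helper (`num_blocks` is for illustration only, not part of `max.graph`):

```python
# Hypothetical helper computing L, the required number of local blocks,
# from the formula above; not part of max.graph.
def num_blocks(output_size, kernel_size, stride, dilation, padding):
    L = 1
    for d in range(len(output_size)):
        L *= ((output_size[d] + 2 * padding[d]
               - dilation[d] * (kernel_size[d] - 1) - 1) // stride[d] + 1)
    return L

# A 4x5 output with a 2x2 kernel, unit stride/dilation, no padding
# needs 3 * 4 = 12 blocks.
print(num_blocks((4, 5), (2, 2), (1, 1), (1, 1), (0, 0)))  # 12
```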
### `gather()` {#max.graph.ops.gather}
> max.graph.ops.gather(input, indices, axis)
Selects elements out of an input tensor by index.
**Parameters:**
* input (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input symbolic tensor to select elements from.
* indices (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – A symbolic tensor of index values to use for selection.
* axis ([int](https://docs.python.org/3/library/functions.html#int)) – The dimension which `indices` indexes from `input`. If negative,
indexes relative to the end of the input tensor. For instance,
`gather(input, indices, axis=-1)` will index against the last
dimension of `input`.
**Returns:**
A new symbolic tensor representing the result of the gather
operation.
### `gather_nd()` {#max.graph.ops.gather_nd}
> max.graph.ops.gather\_nd(input, indices, batch\_dims=0)
Selects elements out of an input tensor by N-dimensional index.
This operation performs N-dimensional indexing into `input` using `indices`.
Unlike [`gather()`](#max.graph.ops.gather), which indexes along a single axis, `gather_nd()` allows
indexing along multiple dimensions simultaneously.
```python
from max.dtype import DType
from max.graph import Graph, TensorType, ops

input_shape = ["a", "b", "c", "d", "e"]
indices_shape = ["a", "f", 3]
input_type = TensorType(DType.bfloat16, input_shape)
indices_type = TensorType(DType.int32, indices_shape)

with Graph("gather_nd", input_types=[input_type, indices_type]) as graph:
    input, indices = graph.inputs
    gathered = ops.gather_nd(input, indices, batch_dims=1)
    print(gathered.type)
    # Output: TensorType(dtype=DType.bfloat16, shape=["a", "f", "e"])
```
In this example:
* `batch_dims` is 1, so there’s 1 shared dimension at the beginning.
* `indices` has an additional dimension “f” which becomes part of the output.
* The last dimension of `indices` is the index vector; values in this vector
are interpreted to be indices into “b”, “c”, and “d”.
* Since `batch_dims (1) + index size (3) < input.rank (5)`, the remaining
dimensions (in this case “e”) are sliced into the output as features.
**Parameters:**
* input (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input symbolic tensor to select elements from.
* indices (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – A symbolic tensor of index values to use for selection.
The last dimension of this tensor must be static. This dimension
will be used to index or slice into `input` immediately following
`batch_dims` initial dimensions. The size of this index dimension
is the number of dimensions it specifies.
* batch\_dims ([int](https://docs.python.org/3/library/functions.html#int)) – The number of leading batch dimensions shared by
`input` and `indices`; 0 by default. `input` and `indices` must
exactly match up to their first `batch_dims` dimensions. This
function does not broadcast.
**Returns:**
A new symbolic tensor representing the result of the gather operation.
The output will have the same dtype as `input`, and will have shape
depending on the inputs, in this order:
* `input.shape[:batch_dims]` – The “broadcast” dimensions (though note
that this function does not broadcast). These dimensions must be
identical between `input` and `indices`.
* `indices.shape[batch_dims:-1]` – The “gather” dimensions; this allows
multi-dimensional tensors of indices. The last dimension is the index vector.
* `input.shape[batch_dims + indices.shape[-1]:]` – The “slice” dimensions.
If `batch_dims` < `input.rank - indices.shape[-1]` (again, this last
is the index vector), then any following dimensions of the inputs are
taken entirely as though slicing.
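The output-shape rule above can be sketched for the `batch_dims=0` case in plain NumPy; `gather_nd_ref` is a hypothetical reference helper, not the MAX op:

```python
import numpy as np

# Reference semantics of gather_nd for batch_dims=0 (hypothetical
# sketch, not the MAX op). The last dimension of `indices` is the index
# vector into the leading dimensions of `input`; any remaining input
# dimensions are sliced into the output.
def gather_nd_ref(input, indices):
    out_shape = indices.shape[:-1] + input.shape[indices.shape[-1]:]
    flat = indices.reshape(-1, indices.shape[-1])
    gathered = np.stack([input[tuple(ix)] for ix in flat])
    return gathered.reshape(out_shape)

x = np.arange(24).reshape(2, 3, 4)
idx = np.array([[0, 1], [1, 2]])   # two index vectors into dims (0, 1)
out = gather_nd_ref(x, idx)
print(out.shape)  # (2, 4): each index vector selects a length-4 slice
```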
### `gelu()` {#max.graph.ops.gelu}
> max.graph.ops.gelu(x, approximate='none')
Computes the elementwise gelu of a symbolic tensor.
Creates a new op node to compute the elementwise gelu of a
symbolic tensor and adds it to the graph, returning the symbolic result.
For `approximate == "none"`, the exact gelu function
$$
\text{gelu}(x) = x \cdot \Phi(x) = 0.5 x \left(1 + \operatorname{erf}\left(x / \sqrt{2}\right)\right)
$$
is computed, where $\Phi$ is the CDF of the standard normal distribution.
For `approximate == "tanh"`, the approximation:
$$
\text{gelu}(x) = 0.5 x \left(1 + \tanh\left(0.7978845608028654 \left(x + 0.044715 x^3\right)\right)\right)
$$
is used.
For `approximate == "quick"`, the approximation:
$$
\text{gelu}(x) = \sigma(1.702 x) \cdot x
$$
is used.
**Parameters:**
* x ([TensorValue](TensorValue.md#max.graph.TensorValue)) – The symbolic tensor to use as the input to the gelu
computation.
* approximate ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The approximation method: `"none"`, `"tanh"`, or `"quick"`.
**Returns:**
A new symbolic tensor value representing the output of the gelu
computation.
**Raises:**
* Error – If the symbol doesn’t represent a tensor value.
* [ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If the approximation method is invalid.
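The three variants above can be compared with a plain-NumPy reference (an illustrative sketch, not the MAX kernels):

```python
import math
import numpy as np

# Plain-NumPy reference for the three gelu variants documented above.
def gelu_exact(x):
    erf = np.array([math.erf(v / math.sqrt(2.0)) for v in x])
    return 0.5 * x * (1.0 + erf)

def gelu_tanh(x):
    return 0.5 * x * (1.0 + np.tanh(0.7978845608028654 * (x + 0.044715 * x**3)))

def gelu_quick(x):
    return x / (1.0 + np.exp(-1.702 * x))  # sigmoid(1.702 * x) * x

x = np.linspace(-3.0, 3.0, 13)
# The tanh form tracks the exact gelu closely over this range.
print(np.max(np.abs(gelu_exact(x) - gelu_tanh(x))))
```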
### `hann_window()` {#max.graph.ops.hann_window}
> max.graph.ops.hann\_window(window\_length, device, periodic=True, dtype=float32)
Calculate a Hann window for a given length.
Hann window function:
$$
H[n] = \frac{1}{2}\left[1 - \cos\left(\frac{2\pi n}{N - 1}\right)\right]
$$
where $N$ is `window_length`.
**Parameters:**
* window\_length ([int](https://docs.python.org/3/library/functions.html#int)) – The length of the window.
* device ([DeviceRef](type.md#max.graph.type.DeviceRef)) – The device to run the operation on.
* periodic ([bool](https://docs.python.org/3/library/functions.html#bool)) – If true, trims off the last
duplicate value from the symmetric window, making the result ready to be
used as a periodic window with functions like stft(). In other words,
`hann_window(L, periodic=True) == hann_window(L + 1, periodic=False)[:-1]`.
* dtype ([DType](../dtype.md#max.dtype.DType)) – The desired data type of the output tensor.
**Returns:**
A 1-D tensor of size (window\_length,) containing the window.
**Raises:**
* [ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If window\_length is negative.
* [TypeError](https://docs.python.org/3/library/exceptions.html#TypeError) – If window\_length is not an integer.
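A NumPy sketch of the formula and the periodic/symmetric relationship noted above (an illustration of the semantics, not the MAX implementation):

```python
import numpy as np

# NumPy sketch of the Hann window and the periodic/symmetric identity above.
def hann(window_length, periodic=True):
    n = np.arange(window_length)
    # The periodic window is the symmetric window of length + 1
    # with the final duplicate sample dropped.
    N = window_length + 1 if periodic else window_length
    return 0.5 * (1.0 - np.cos(2.0 * np.pi * n / (N - 1)))

L = 8
w_per = hann(L, periodic=True)
w_sym = hann(L + 1, periodic=False)
print(np.allclose(w_per, w_sym[:-1]))  # True
```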
### `inplace_custom()` {#max.graph.ops.inplace_custom}
> max.graph.ops.inplace\_custom(name, device, values, out\_types=None, parameters=None)
Creates a node to execute an in-place custom graph operation in the graph.
The custom op should be registered by annotating a function with the
[@compiler.register](/mojo/manual/decorators/compiler-register/)
decorator.
**Parameters:**
* name ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The op name provided to `@compiler.register`.
* device ([Device](../driver.md#max.driver.Device) | [DeviceRef](type.md#max.graph.type.DeviceRef)) – Device that the op is assigned to.
This becomes a target parameter to the kernel.
* values ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[Value](Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The op function’s arguments.
* parameters ([dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [bool](https://docs.python.org/3/library/functions.html#bool) | [int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [DType](../dtype.md#max.dtype.DType)] | None) – Dictionary of extra parameters expected by the kernel.
* out\_types ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[Type](type.md#max.graph.type.Type)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]] | None)
### `irfft()` {#max.graph.ops.irfft}
> max.graph.ops.irfft(input\_tensor, n=None, axis=-1, normalization=Normalization.BACKWARD, input\_is\_complex=False, buffer\_size\_mb=512)
Compute the inverse real FFT of the input tensor.
**Parameters:**
* input\_tensor (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue) – The input tensor to compute the inverse real FFT of.
* n ([int](https://docs.python.org/3/library/functions.html#int) | None) – The size of the output tensor. Must be an int, and cannot be a
symbolic Buffer. The input tensor will be padded or truncated to
n // 2 + 1 along the specified axis.
* axis ([int](https://docs.python.org/3/library/functions.html#int)) – The axis to compute the inverse real FFT of.
* normalization (Normalization | [str](https://docs.python.org/3/library/stdtypes.html#str)) – The normalization to apply to the output tensor.
Can be “backward”, “ortho”, or “forward”. When “backward”, the
output is divided by n. When “ortho”, the output is divided by
sqrt(n). When “forward”, no normalization is applied.
* input\_is\_complex ([bool](https://docs.python.org/3/library/functions.html#bool)) – Whether the input tensor is already interleaved
complex. The last dimension of the input tensor must be 2, and is
excluded from the dimension referred to by axis.
* buffer\_size\_mb ([int](https://docs.python.org/3/library/functions.html#int)) – The estimated size of a persistent buffer to use for
storage of intermediate results. Needs to be the same across multiple
calls to irfft within the same graph. Otherwise, multiple buffers
will be allocated.
**Returns:**
The inverse real FFT of the input tensor. The shape of the output tensor
is the same as the shape of the input tensor, except for the axis that
the inverse real FFT is computed over, which is replaced by n.
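The `n // 2 + 1` spectrum length and the axis replacement can be seen with NumPy's equivalent transform (an analogue of the semantics, not the MAX op):

```python
import numpy as np

# NumPy analogue of the inverse real FFT described above: the input spectrum
# has length n // 2 + 1, and the transformed axis is replaced by n.
signal = np.array([1.0, 2.0, 3.0, 4.0])
spectrum = np.fft.rfft(signal)           # length 4 // 2 + 1 == 3
recovered = np.fft.irfft(spectrum, n=4)  # "backward" normalization: divide by n
print(recovered)  # [1. 2. 3. 4.]
```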
### `layer_norm()` {#max.graph.ops.layer_norm}
> max.graph.ops.layer\_norm(input, gamma, beta, epsilon)
Performs layer normalization over the last dimension of the input tensor.
**Parameters:**
* input ([TensorValue](TensorValue.md#max.graph.TensorValue)) – The input tensor to normalize.
* gamma (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The gamma parameter of the normalization.
* beta (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The beta parameter of the normalization.
* epsilon ([float](https://docs.python.org/3/library/functions.html#float)) – The epsilon parameter of the normalization.
**Returns:**
A graph tensor value with the normalization applied.
**Raises:**
* [ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If gamma size doesn’t match the last dimension of input.
* [ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If beta size doesn’t match the last dimension of input.
* [ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If epsilon is not positive.
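A NumPy reference for the normalization described by the gamma, beta, and epsilon parameters above (an illustrative sketch, not the MAX kernel):

```python
import numpy as np

# NumPy reference for the normalization above: normalize over the last
# dimension, then scale by gamma and shift by beta.
def layer_norm_ref(x, gamma, beta, epsilon=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon) * gamma + beta

x = np.array([[1.0, 2.0, 3.0, 4.0]])
out = layer_norm_ref(x, gamma=np.ones(4), beta=np.zeros(4))
print(out.mean())  # ~0: each row is normalized to zero mean, unit variance
```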
### `log()` {#max.graph.ops.log}
> max.graph.ops.log(x)
Computes the elementwise natural logarithm of a symbolic tensor.
Creates a new op node to compute the elementwise natural logarithm of a
symbolic tensor and adds it to the graph, returning the symbolic result.
The natural logarithm is used in loss functions, normalization, and
probability calculations in machine learning.
```python
import max.functional as F
from max.tensor import Tensor

# Create input tensor (positive values only)
x = Tensor.constant([1.0, 2.718, 7.389, 20.0])

# Compute natural logarithm
result = F.log(x)
print(result)
# Output: [0.0, 1.0, 2.0, 2.996...]
# (log(1) = 0, log(e) = 1, log(e^2) = 2)
```
The natural logarithm function `log` is defined as the inverse of the
exponential function `exp()`. In other words, it computes the value `y` in
the equation `x = e^y` where `e` is Euler’s number.
`log(x)` is undefined for `x <= 0` for real numbers. Complex numbers
are currently unsupported.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The symbolic tensor to use as the input to the natural logarithm
computation.
**Returns:**
A new symbolic tensor value representing the output of the natural
logarithm computation.
**Raises:**
Error – If the symbol doesn’t represent a tensor value.
### `masked_scatter()` {#max.graph.ops.masked_scatter}
> max.graph.ops.masked\_scatter(input, mask, updates, out\_dim)
Creates a new symbolic tensor where the updates are written to input where mask is true.
**Parameters:**
* input (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input symbolic tensor to write elements to.
* mask (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – A symbolic tensor of boolean values to update.
* updates (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – A symbolic tensor of elements to write to input.
* out\_dim ([int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]) – The new data-dependent dimension.
**Returns:**
A new symbolic tensor representing the result of the masked\_scatter operation.
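The write semantics above can be sketched with NumPy's boolean-mask assignment (an illustration, not the MAX API):

```python
import numpy as np

# NumPy sketch of masked_scatter semantics: `updates` values are written
# into `data` at the positions where `mask` is True, in row-major order.
data = np.array([1.0, 2.0, 3.0, 4.0])
mask = np.array([True, False, True, False])
updates = np.array([10.0, 30.0])

result = data.copy()
result[mask] = updates
print(result)  # [10.  2. 30.  4.]
```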
### `matmul()` {#max.graph.ops.matmul}
> max.graph.ops.matmul(lhs, rhs)
Computes the matrix multiplication of two tensor graph values.
Performs general matrix multiplication with broadcasting. Matrix multiplication
is fundamental to neural networks, used for linear transformations, attention
mechanisms, and fully connected layers.
```python
from max.tensor import Tensor

# Create two 2x2 matrices
x = Tensor.constant([[1.0, 2.0], [3.0, 4.0]])  # Shape: (2, 2)
w = Tensor.constant([[5.0, 6.0], [7.0, 8.0]])  # Shape: (2, 2)

# Matrix multiply using @ operator (uses matmul internally)
result = x @ w
print("Matrix multiplication result:")
print(result)
# Output: [[19.0, 22.0],
#          [43.0, 50.0]]
# Computed as: result[i,j] = sum(x[i,k] * w[k,j])

# Can also call directly via functional API
import max.functional as F
result2 = F.matmul(x, w)
# Same result as x @ w
```
If the lhs is 1D, it will be reshaped to `1xD`.
If the rhs is 1D, it will be reshaped to `Dx1`.
In both cases, the additional 1 dimensions will be removed from the
output shape.
For the multiplication, the innermost (rightmost) 2 dimensions are treated
as a matrix.
The lhs matrix will have the shape `MxK`.
The rhs matrix will have the shape `KxN`.
The output will have the shape `MxN`.
The `K` dimensions must be equivalent in both matrices.
The remaining outer dimensions will be broadcast.
**Parameters:**
* lhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The left-hand side input tensor.
* rhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The right-hand side input tensor.
**Returns:**
A tensor graph value representing the matrix product of `lhs` and `rhs`.
For 2D inputs, the output shape is `(M, N)` where `lhs` is `(M, K)`
and `rhs` is `(K, N)`. For higher-dimensional inputs, batch
dimensions are preserved and the operation is applied to the last two
dimensions of each input.
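The 1-D promotion and broadcasting rules above mirror NumPy's matmul semantics, which can serve as a quick sanity check:

```python
import numpy as np

# The 1-D promotion rules above mirror NumPy's matmul semantics.
v = np.array([1.0, 2.0])                 # rank-1, D == 2
m = np.array([[1.0, 2.0], [3.0, 4.0]])
print((v @ m).shape)   # (2,)  lhs treated as 1xD; leading 1 removed
print((m @ v).shape)   # (2,)  rhs treated as Dx1; trailing 1 removed

b = np.zeros((5, 2, 3))                  # outer batch dim broadcasts
w = np.zeros((3, 4))
print((b @ w).shape)   # (5, 2, 4)
```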
### `max()` {#max.graph.ops.max}
> max.graph.ops.max(x, y=None, /, axis=None)
Overload for ops.elementwise.max and ops.reduction.max.
* If two tensors are provided, axis is ignored and returns an elementwise maximum.
* If one tensor is provided, compute ops.reduction.max on the tensor and axis.
### `max_pool2d()` {#max.graph.ops.max_pool2d}
> max.graph.ops.max\_pool2d(input, kernel\_size, stride=1, dilation=1, padding=0, ceil\_mode=False)
Perform a 2D max pooling operation on the input tensor.
This function applies a 2D max pooling operation to the input tensor \[N, H, W, C].
The pooling operation slides a window of size kernel\_size over the input
tensor, and selects the maximum value within each window.
**Parameters:**
* input (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input tensor to perform the pooling operation on.
* kernel\_size ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)], [int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The size of the sliding blocks.
* stride ([int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – The stride of the sliding blocks in the input dimension.
* dilation ([int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – The spacing between the kernel elements.
* padding ([int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – 0-paddings to be added on both sides of the inputs.
* ceil\_mode ([bool](https://docs.python.org/3/library/functions.html#bool)) – If true, use ceil instead of floor to compute the output shape.
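The sliding-window selection above can be sketched in plain NumPy on an `[N, H, W, C]` tensor (an illustration of the windowing, not the MAX kernel):

```python
import numpy as np

# Illustrative NHWC max pooling with a 2x2 kernel and stride 2.
x = np.arange(16, dtype=np.float32).reshape(1, 4, 4, 1)  # [N, H, W, C]
kh = kw = 2
out = np.empty((1, 2, 2, 1), dtype=np.float32)
for i in range(2):
    for j in range(2):
        # Select the maximum value within each kernel-sized window.
        window = x[:, i * 2:i * 2 + kh, j * 2:j * 2 + kw, :]
        out[:, i, j, :] = window.max(axis=(1, 2))
print(out[0, :, :, 0])
# [[ 5.  7.]
#  [13. 15.]]
```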
### `mean()` {#max.graph.ops.mean}
> max.graph.ops.mean(x, axis=-1)
Reduces a symbolic tensor using a mean operation.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input tensor for the operation.
* axis ([int](https://docs.python.org/3/library/functions.html#int)) – The axis along which to compute the reduction. If negative,
indexes from the last dimension. For example, a value of -1 will
compute the reduction along the last dimension.
**Returns:**
A symbolic tensor representing the result of the mean operation.
The tensor will have the same rank as the input tensor, and the same
shape except along the `axis` dimension which will have size 1.
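The rank-preserving behavior described above matches NumPy's `keepdims=True`:

```python
import numpy as np

# mean keeps the reduced axis with size 1, like NumPy's keepdims=True:
x = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
result = x.mean(axis=-1, keepdims=True)
print(result.shape)  # (2, 1) — same rank, reduced axis has size 1
```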
### `min()` {#max.graph.ops.min}
> max.graph.ops.min(x, y=None, /, axis=None)
Overload for ops.elementwise.min and ops.reduction.min.
* If two tensors are provided, axis is ignored and returns an elementwise minimum.
* If one tensor is provided, compute ops.reduction.min on the tensor and axis.
### `nonzero()` {#max.graph.ops.nonzero}
> max.graph.ops.nonzero(x, out\_dim)
Returns the indices of all nonzero elements in a tensor.
Returns a tensor of indices of the nonzero values in the given tensor. The
return value is a 2D tensor of shape `[out_dim x rank_in]`, where
out\_dim is the number of nonzero elements in the input tensor, and
rank\_in is the rank of the input tensor. Indices are generated in
row-major order.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input symbolic tensor.
* out\_dim ([int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]) – The newly generated dimension that is sized for the number of
nonzero elements.
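The `[out_dim x rank_in]` index layout described above matches NumPy's `argwhere` (an illustration, not the MAX API):

```python
import numpy as np

# NumPy sketch of the [out_dim x rank_in] row-major index layout above.
x = np.array([[0, 1], [2, 0]])
indices = np.argwhere(x)  # row-major indices of the nonzero elements
print(indices.shape)  # (2, 2): two nonzero elements, rank-2 input
```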
### `outer()` {#max.graph.ops.outer}
> max.graph.ops.outer(lhs, rhs)
Computes the outer product of two symbolic vectors.
**Parameters:**
* lhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The left side of the product. Whatever its shape,
it will be flattened to a rank-1 vector.
* rhs (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The right side of the product. Whatever its shape,
it will be flattened to a rank-1 vector. Must have the
same number of elements as lhs.
**Returns:**
A symbolic tensor representing the
[outer product](https://en.wikipedia.org/wiki/Outer_product)
of the two input vectors. It will have rank 2, with the dimension
sizes being the number of elements of lhs and rhs respectively.
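The flattening and rank-2 result described above can be sketched with NumPy's equivalent:

```python
import numpy as np

# Outer product sketch: both inputs are flattened to rank-1 vectors and
# the result has rank 2.
a = np.array([[1.0, 2.0]])  # any shape; flattened to 2 elements
b = np.array([3.0, 4.0])
result = np.outer(a, b)
print(result)
# [[3. 4.]
#  [6. 8.]]
```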
### `pad()` {#max.graph.ops.pad}
> max.graph.ops.pad(input, paddings, mode='constant', value=0)
Pads a tensor with constant values.
Adds padding to the input tensor using the specified padding values.
Currently only constant padding mode is supported.
**Parameters:**
* input (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input tensor to pad.
* paddings ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int)]) – Sequence of padding values. The padding values are applied
symmetrically to each dimension. For a tensor with rank N,
paddings should contain 2\*N values: [pad\_before\_dim0, pad\_after\_dim0,
pad\_before\_dim1, pad\_after\_dim1, …].
* mode ([Literal](https://docs.python.org/3/library/typing.html#typing.Literal)\['constant']) – The padding mode. Currently only “constant” is supported.
* value (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The constant value to use for padding.
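The flat `[pad_before_dim0, pad_after_dim0, …]` layout above corresponds to NumPy's pair-per-dimension form, which can illustrate the result shape:

```python
import numpy as np

# The flat paddings layout [before_dim0, after_dim0, before_dim1, after_dim1]
# corresponds to NumPy's pair-per-dimension form.
x = np.ones((2, 2))
paddings = [1, 1, 0, 2]  # dim 0: 1 before / 1 after; dim 1: 0 before / 2 after
pairs = list(zip(paddings[0::2], paddings[1::2]))
result = np.pad(x, pairs, mode="constant", constant_values=0)
print(result.shape)  # (4, 4)
```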
### `permute()` {#max.graph.ops.permute}
> max.graph.ops.permute(x, dims)
Permutes all dimensions of a symbolic tensor.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input symbolic tensor to transpose.
* dims ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]) – The desired ordering of the dimensions in the output tensor.
**Returns:**
A new symbolic tensor with the dimensions permuted to match the passed in order.
It has the same elements and dtype, but the order of the elements
is different according to the permutation.
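The dimension reordering above behaves like NumPy's `transpose` with explicit axes:

```python
import numpy as np

# permute reorders all dimensions at once, like np.transpose with axes:
x = np.zeros((2, 3, 4))
result = np.transpose(x, (2, 0, 1))  # dims[i] names the input axis placed at i
print(result.shape)  # (4, 2, 3)
```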
### `print()` {#max.graph.ops.print}
> max.graph.ops.print(value, label='debug\_tensor')
Prints the value of a tensor or a string during graph execution.
This function outputs the current value of a tensor and is
primarily used for debugging within the context of the MAX
Engine and its graph execution framework. It is particularly useful for
verifying that the intermediate results of your computations are as expected.
By printing the tensor values, you can visualize the data flowing through the
graph, which helps in understanding how the operations are transforming
the data.
By assigning a label, you can identify which tensor’s value is being
printed, especially when there are multiple print statements in a
complex graph.
```python
from typing import Any
import numpy as np
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

def add_tensors(a: np.ndarray, b: np.ndarray) -> dict[str, Any]:
    input_type = TensorType(dtype=DType.float32, shape=(1,), device=DeviceRef.CPU())
    with Graph(
        "simple_add_graph", input_types=(input_type, input_type)
    ) as graph:
        lhs, rhs = graph.inputs
        out = ops.add(lhs, rhs)
        ops.print(out, label="addition_output")  # Pass the output tensor here
        graph.output(out)
        print("final graph:", graph)
```
**Parameters:**
* value ([str](https://docs.python.org/3/library/stdtypes.html#str) | [TensorValue](TensorValue.md#max.graph.TensorValue)) – The value to print. Can be either a string or a TensorValue.
* label ([str](https://docs.python.org/3/library/stdtypes.html#str)) – A label to identify the printed value. Defaults to
`debug_tensor`.
**Return type:**
None
### `qmatmul()` {#max.graph.ops.qmatmul}
> max.graph.ops.qmatmul(encoding, config, lhs, \*rhs)
Performs matrix multiplication between floating point and quantized
tensors.
This quantizes the `lhs` floating point value to match the encoding of the
`rhs` quantized value, performs matmul, and then dequantizes the result.
Beware that, compared to a regular matmul op, this one expects the `rhs`
value to be transposed. For example, if the `lhs` shape is `[32, 64]`, and
the quantized `rhs` shape is also `[32, 64]`, then the output shape is
`[32, 32]`.
That is, this function returns the result from:
> dequantize(quantize(lhs) @ transpose(rhs))
The last two dimensions in `lhs` are treated as matrices and multiplied
by `rhs` (which must be a 2D tensor). Any remaining dimensions in `lhs`
are broadcast dimensions.
NOTE: Currently this supports Q4\_0, Q4\_K, and Q6\_K encodings only.
**Parameters:**
* encoding ([QuantizationEncoding](quantization.md#max.graph.quantization.QuantizationEncoding)) – The quantization encoding to use.
* lhs ([TensorValue](TensorValue.md#max.graph.TensorValue)) – The non-quantized, left-hand-side of the matmul.
* \*rhs ([TensorValue](TensorValue.md#max.graph.TensorValue)) – The transposed and quantized right-hand-side of the matmul and
auxiliary tensor (if any). Must be rank 2 and in a supported
[quantization encoding](/max/api/mojo/graph/quantization/).
* config ([QuantizationConfig](quantization.md#max.graph.quantization.QuantizationConfig) | None)
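A shape-level sketch in plain NumPy (ignoring quantization entirely) of the transposed-rhs convention above:

```python
import numpy as np

# Shape-level sketch of the transposed-rhs convention above; the actual op
# quantizes lhs, multiplies against the quantized rhs, and dequantizes.
lhs = np.ones((32, 64))
rhs = np.ones((32, 64))  # stands in for the quantized, transposed rhs
result = lhs @ rhs.T
print(result.shape)  # (32, 32)
```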
### `range()` {#max.graph.ops.range}
> max.graph.ops.range(start, stop, step=1, out\_dim=None, \*, dtype, device)
Creates a sequence of numbers. The sequence goes from start with
increments of size step up to (but not including) stop. start, stop,
and step must have the same element type.
Note the following restrictions on input values:
1. step must be non-zero
2. stop - start must be zero or have the same sign as step
**Parameters:**
* start (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The start of the range to generate.
* stop (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The range will be generated up to, but not including, this value.
* step (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The step size for the range.
* out\_dim ([int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | None) – The expected output dimensions returned by the range op.
These will be asserted to be correct at graph execution time.
* device ([Device](../driver.md#max.driver.Device) | [DeviceRef](type.md#max.graph.type.DeviceRef)) – Device of the result tensor.
* dtype ([DType](../dtype.md#max.dtype.DType)) – Data type of the result tensor. If not specified, defaults to
float32 for numeric inputs or infers from tensor inputs.
**Returns:**
A symbolic tensor value containing the defined range of values.
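The length rule mirrors `numpy.arange`, generalized to floats. A plain-Python sketch of how many elements the op produces under the restrictions above (the function name is illustrative, not part of the MAX API):

```python
import math

def range_length(start: float, stop: float, step: float) -> int:
    """Number of elements a range from start to stop (exclusive) with
    the given step would contain, mirroring numpy.arange semantics."""
    if step == 0:
        raise ValueError("step must be non-zero")
    span = stop - start
    if span != 0 and (span > 0) != (step > 0):
        raise ValueError("stop - start must be zero or share the sign of step")
    return max(0, math.ceil(span / step))

print(range_length(0, 10, 2))       # 5 elements: 0, 2, 4, 6, 8
print(range_length(1.0, 2.0, 0.25)) # 4 elements: 1.0, 1.25, 1.5, 1.75
```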
### `rebind()` {#max.graph.ops.rebind}
> max.graph.ops.rebind(x, shape, message='', layout=None)
Rebinds a symbolic tensor to a specified set of dimensions.
This does not mutate the symbolic tensor passed in, but instead adds a
runtime assertion that the input's symbolic shape is equivalent to the
given `shape`. For example, if the input tensor shape has
dynamic/unknown sizes, this will assert the fixed sizes that may be required
for a subsequent operation.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input symbolic tensor to rebind.
* shape ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The symbolic shape to assert for `x`, as a list of
[`Dim`](/max/api/python/graph/type/Dim) values.
* message ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The message printed if the rebind fails at runtime.
* layout ([FilterLayout](type.md#max.graph.type.FilterLayout) | None) – A layout of the weights used by some operations like conv.
**Returns:**
A symbolic tensor with the same elements and shape as the given tensor,
but with the symbolic shape asserted to equal `shape`.
### `relu()` {#max.graph.ops.relu}
> max.graph.ops.relu(x)
Computes the elementwise ReLU (Rectified Linear Unit) of a symbolic tensor.
Creates a new op node to compute the elementwise ReLU of a symbolic tensor
and adds it to the graph, returning the symbolic result. ReLU is defined as
`relu(x) = max(0, x)`, setting all negative values to zero while leaving
positive values unchanged.
ReLU is one of the most common activation functions in neural networks due to
its computational efficiency and effectiveness in addressing the vanishing
gradient problem.
```python
import max.functional as F
from max.tensor import Tensor
## Create input with negative and positive values
x = Tensor.constant([[-2.0, -1.0, 0.0], [1.0, 2.0, 3.0]])
## Apply ReLU activation
result = F.relu(x)
print(result)
## Output: [[0.0, 0.0, 0.0], [1.0, 2.0, 3.0]]
## Negative values become 0, positive values unchanged
```
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The symbolic tensor to use as the input to the relu computation.
**Returns:**
A new symbolic tensor value representing the output of the relu
value computation.
**Raises:**
Error – If the symbol doesn’t represent a tensor value.
### `repeat_interleave()` {#max.graph.ops.repeat_interleave}
> max.graph.ops.repeat\_interleave(x, repeats, axis=None, out\_dim=None)
Repeats elements of a tensor along the given dimension.
Modeled after `torch.repeat_interleave`.
For example, given `repeats=2` and the following input:
```python
## Input tensor with shape (2, 2)
input = TensorValue(x) # Contains [[1.0, 2.0], [3.0, 4.0]]
```
`repeat_interleave` with `axis=0`:
```python
## Output tensor with shape (4, 2)
output = repeat_interleave(input, repeats=2, axis=0)
## Contains [[1.0, 2.0], [1.0, 2.0], [3.0, 4.0], [3.0, 4.0]]
```
`repeat_interleave` with `axis=1`:
```python
## Output tensor with shape (2, 4)
output = repeat_interleave(input, repeats=2, axis=1)
## Contains [[1.0, 1.0, 2.0, 2.0], [3.0, 3.0, 4.0, 4.0]]
```
`repeat_interleave` with `repeats=[2, 3]` and `axis=0`:
```python
repeat_value = TensorValue([2, 3])
## Output tensor with shape (5, 2)
output = repeat_interleave(input, repeats=repeat_value, axis=0)
## Contains [[1.0, 2.0], [1.0, 2.0], [3.0, 4.0], [3.0, 4.0], [3.0, 4.0]]
```
`repeat_interleave` with `axis=None` (the default), which flattens the input first:
```python
## Output tensor with shape (8,)
output = repeat_interleave(input, repeats=2) # axis = None
## Contains [1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]
```
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input tensor.
* repeats ([int](https://docs.python.org/3/library/functions.html#int) | [TensorValue](TensorValue.md#max.graph.TensorValue)) – The number of repetitions for each element.
* axis ([int](https://docs.python.org/3/library/functions.html#int) | None) – The dimension along which to repeat values. If axis is not
specified or None (the default), flatten the input array
and repeat the flattened values.
* out\_dim ([int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | None)
**Returns:**
A symbolic tensor with the elements interleaved.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If `repeats` is non-positive or if `axis` is out of range.
### `reshape()` {#max.graph.ops.reshape}
> max.graph.ops.reshape(x, shape)
Reshapes a symbolic tensor.
The number and order of the elements in the tensor is unchanged.
In other words, if you were to iterate over elements in the tensor
by major dimension to minor dimension, the iteration order would stay
the same.
If a value of -1 is present in the shape, that dimension becomes an
automatically calculated dimension collecting all unspecified dimensions.
Its length becomes the number of elements in the original tensor
divided by the product of the other dimensions in the new shape.
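The -1 inference rule can be sketched in plain Python (the helper name is illustrative, not part of the MAX API):

```python
from math import prod

def resolve_reshape(num_elements: int, shape: list) -> list:
    """Resolve a single -1 in `shape` so the element counts match."""
    if shape.count(-1) > 1:
        raise ValueError("at most one dimension may be -1")
    known = prod(d for d in shape if d != -1)
    if -1 in shape:
        if num_elements % known != 0:
            raise ValueError("element counts do not match")
        return [num_elements // known if d == -1 else d for d in shape]
    if known != num_elements:
        raise ValueError("element counts do not match")
    return shape

print(resolve_reshape(24, [2, -1, 3]))  # [2, 4, 3]
```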
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input symbolic tensor to reshape.
* shape ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The new shape as a list of dimensions.
A single dimension may be -1.
**Returns:**
A symbolic tensor with the same elements as the original tensor, but
in a new shape. Its symbolic shape is the same as `shape`.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If the input and target shapes have a different number of elements.
### `resize()` {#max.graph.ops.resize}
> max.graph.ops.resize(input, shape, interpolation=InterpolationMode.BILINEAR)
Resize the input tensor to the given shape.
This function resizes a tensor using the specified interpolation method.
The tensor is expected to have NCHW format (batch, channels, height, width).
**Parameters:**
* input (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input tensor to resize. Must have rank 4 in NCHW format.
* shape ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – Desired output shape of length 4 corresponding to (N, C, H, W).
* interpolation ([InterpolationMode](#max.graph.ops.InterpolationMode)) – Desired interpolation enum defined by InterpolationMode.
Defaults to InterpolationMode.BILINEAR; however, only BICUBIC is
currently supported.
**Returns:**
A resized tensor with the shape specified by the shape argument.
**Raises:**
* [ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If the input doesn’t have rank 4, shape has wrong number
of elements, or unsupported interpolation mode is specified.
* [NotImplementedError](https://docs.python.org/3/library/exceptions.html#NotImplementedError) – If single integer size or non-BICUBIC interpolation
mode is specified.
### `scatter()` {#max.graph.ops.scatter}
> max.graph.ops.scatter(input, updates, indices, axis=-1)
Creates a new symbolic tensor where the updates are written to input according to indices.
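The update rule can be sketched in plain Python for the 1-D, `axis=0` case (illustrative only, not the MAX implementation): each entry of `updates` is written into a copy of `input` at the position given by the matching entry of `indices`.

```python
def scatter_1d(input_, updates, indices):
    """Sketch of scatter along axis 0 for a 1-D tensor: write
    updates[i] into position indices[i] of a copy of input_."""
    out = list(input_)
    for idx, upd in zip(indices, updates):
        out[idx] = upd
    return out

print(scatter_1d([0.0, 0.0, 0.0, 0.0], [1.5, 2.5], [3, 1]))
# [0.0, 2.5, 0.0, 1.5]
```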
**Parameters:**
* input (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input symbolic tensor to write elements to.
* updates (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – A symbolic tensor of elements to write to input.
* indices (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The positions in input to update.
* axis ([int](https://docs.python.org/3/library/functions.html#int)) – The axis along which indices indexes into.
**Returns:**
A new symbolic tensor representing the result of the scatter operation.
### `scatter_nd()` {#max.graph.ops.scatter_nd}
> max.graph.ops.scatter\_nd(input, updates, indices)
Creates a new symbolic tensor where the updates are scattered into input at specified indices.
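For full indexing on a rank-2 input, each row of `indices` names one coordinate to overwrite. A plain-Python sketch (illustrative only, not the MAX implementation):

```python
def scatter_nd_2d(input_, updates, indices):
    """Sketch of scatter_nd with full indexing on a rank-2 tensor:
    indices[i] is an (row, col) coordinate receiving updates[i]."""
    out = [row[:] for row in input_]  # deep-enough copy for rank 2
    for (r, c), upd in zip(indices, updates):
        out[r][c] = upd
    return out

grid = [[0, 0], [0, 0]]
print(scatter_nd_2d(grid, [7, 9], [[0, 1], [1, 0]]))
# [[0, 7], [9, 0]]
```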
**Parameters:**
* input (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input symbolic tensor to write elements to.
* updates (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – A symbolic tensor of elements to write to input.
* indices (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – A tensor of indices specifying where to write updates.
Shape should be \[num\_updates, rank] for full indexing or
\[num\_updates, k] for partial indexing where k < rank.
**Returns:**
A new symbolic tensor representing the result of the scatter\_nd operation.
### `shape_to_tensor()` {#max.graph.ops.shape_to_tensor}
> max.graph.ops.shape\_to\_tensor(shape)
Converts a shape to a tensor.
This is useful for using a shape attribute in an op that expects a tensor
value.
**Parameters:**
shape ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – the shape attribute of a tensor value.
**Returns:**
The TensorValue containing the same value as shape.
**Example:**
```pycon
>>> import numpy as np
>>> from max.dtype import DType
>>> from max.graph import ops
>>> from max.graph.type import DeviceRef
>>> x = ops.constant(np.zeros((1,)), DType.int64, device=DeviceRef.CPU())
>>> result = ops.stack([
...     x,
...     ops.shape_to_tensor(x.shape),
... ])
>>> result
TensorValue(dtype=int64, shape=[StaticDim(dim=2), StaticDim(dim=1)])
```
### `shard_and_stack()` {#max.graph.ops.shard_and_stack}
> max.graph.ops.shard\_and\_stack(inputs, devices, axis=0)
Shards a list of input tensors along a specified axis, producing multiple outputs.
This operation takes multiple input tensors, splits each along the specified axis
into len(devices) chunks, and returns one output tensor per device. Each output
contains the chunks at the corresponding index stacked from all inputs along
a new dimension 0.
This is useful for distributing model weights across multiple devices in
tensor parallel configurations.
For example, with 2 inputs A and B, axis=0, and 2 devices:
* Input A shape \[10, D], Input B shape \[10, D]
* Output 0: stack(\[A\[0:5], B\[0:5]]) -> shape \[2, 5, D] on devices\[0]
* Output 1: stack(\[A\[5:10], B\[5:10]]) -> shape \[2, 5, D] on devices\[1]
With axis=1 and 2 devices:
* Input A shape \[D, 10], Input B shape \[D, 10]
* Output 0: stack(\[A\[:, 0:5], B\[:, 0:5]]) -> shape \[2, D, 5] on devices\[0]
* Output 1: stack(\[A\[:, 5:10], B\[:, 5:10]]) -> shape \[2, D, 5] on devices\[1]
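The chunking rule for `axis=0` can be sketched in plain Python with nested lists standing in for tensors (illustrative only, not the MAX implementation, and without the device placement):

```python
def shard_rows(inputs, num_devices):
    """Sketch of shard_and_stack for axis=0: split each input into
    num_devices row-chunks, then group chunk i from every input."""
    rows = len(inputs[0])
    if rows % num_devices != 0:
        raise ValueError("axis size must be evenly divisible by num_devices")
    chunk = rows // num_devices
    return [
        [t[i * chunk:(i + 1) * chunk] for t in inputs]  # one output per device
        for i in range(num_devices)
    ]

A = [[1, 1], [2, 2], [3, 3], [4, 4]]  # shape [4, 2]
B = [[5, 5], [6, 6], [7, 7], [8, 8]]  # shape [4, 2]
out = shard_rows([A, B], num_devices=2)
print(out[0])  # [[[1, 1], [2, 2]], [[5, 5], [6, 6]]] - shape [2, 2, 2]
print(out[1])  # [[[3, 3], [4, 4]], [[7, 7], [8, 8]]]
```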
**Parameters:**
* inputs ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)]) – A list of symbolic tensors to shard. All tensors must have
the same shape, dtype, and device.
* devices ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[Device](../driver.md#max.driver.Device) | [DeviceRef](type.md#max.graph.type.DeviceRef)]) – Target devices for each output tensor. The number of devices
determines the number of splits. Each output tensor
will be placed on the corresponding device. This enables direct
host-to-device transfer without intermediate CPU storage.
* axis ([int](https://docs.python.org/3/library/functions.html#int)) – The axis along which to split each input tensor. Defaults to 0.
Supports negative indexing (e.g., -1 for last axis).
**Returns:**
A list of len(devices) tensors, each with shape
\[num\_inputs, D0, …, Daxis//len(devices), …, Dn-1] where the input
shape is \[D0, …, Daxis, …, Dn-1]. Output i contains the stacked
chunks at position i from all input tensors, placed on devices\[i].
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If inputs list is empty, if devices list is empty, if input
tensors don’t have matching shapes, if the dimension size at the
axis is not evenly divisible by len(devices), or if axis is out of
bounds.
### `sigmoid()` {#max.graph.ops.sigmoid}
> max.graph.ops.sigmoid(x)
Computes the elementwise sigmoid activation of a symbolic tensor.
Creates a new op node to compute the elementwise sigmoid of a symbolic
tensor and adds it to the graph, returning the symbolic result. Sigmoid
is defined as `sigmoid(x) = 1 / (1 + exp(-x))`, mapping all input values
to the range (0, 1).
The sigmoid function is commonly used for binary classification tasks and
as an activation function in neural networks, particularly in output layers
for probability prediction.
```python
import max.functional as F
from max.tensor import Tensor
## Create input tensor
x = Tensor.constant([[-2.0, -1.0, 0.0], [1.0, 2.0, 3.0]])
## Apply sigmoid activation
result = F.sigmoid(x)
print(result)
## Output: [[0.119, 0.269, 0.5], [0.731, 0.881, 0.953]]
## All values mapped to range (0, 1)
```
**Parameters:**
* x ([TensorValue](TensorValue.md#max.graph.TensorValue)) – The symbolic tensor to use as the input to the sigmoid computation.
**Returns:**
A new symbolic tensor value representing the output of the sigmoid
value computation.
**Raises:**
Error – If the symbol doesn’t represent a tensor value.
### `silu()` {#max.graph.ops.silu}
> max.graph.ops.silu(x)
Computes the elementwise silu of a symbolic tensor.
Creates a new op node to compute the elementwise silu of a
symbolic tensor and adds it to the graph, returning the symbolic result.
`silu` is defined as `silu(x) = x * sigmoid(x)`.
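The scalar math behind the op is easy to check in plain Python (illustrative only, not the MAX implementation):

```python
import math

def silu(x: float) -> float:
    """silu(x) = x * sigmoid(x), where sigmoid(x) = 1 / (1 + exp(-x))."""
    return x * (1.0 / (1.0 + math.exp(-x)))

print(round(silu(0.0), 3))  # 0.0
print(round(silu(1.0), 3))  # 0.731
```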
**Parameters:**
* x ([TensorValue](TensorValue.md#max.graph.TensorValue)) – The symbolic tensor to use as the input to the silu computation.
**Returns:**
A new symbolic tensor value representing the output of the silu
value computation.
**Raises:**
Error – If the symbol doesn’t represent a tensor value.
### `slice_tensor()` {#max.graph.ops.slice_tensor}
> max.graph.ops.slice\_tensor(x, indices)
Slices out a subtensor view of the input tensor based on indices.
The semantics of [`slice_tensor()`](#max.graph.ops.slice_tensor) follow NumPy slicing semantics with the
following restrictions:
* Slice indices must not index out of `[-dim - 1, dim - 1]` for negative step,
or `[-dim, dim]` for positive step.
```python
## Reverse a tensor.
slice_tensor(x, [slice(None, None, -1)])
## Unsqueeze the second last dimension of a tensor.
slice_tensor(x, [..., None, slice(None)])
```
**Returns:**
The sliced subtensor of x.
**Parameters:**
* x ([TensorValue](TensorValue.md#max.graph.TensorValue))
* indices (SliceIndices)
### `split()` {#max.graph.ops.split}
> max.graph.ops.split(x, split\_sizes, axis=0)
Splits the input tensor into multiple tensors along a given dimension.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input symbolic tensor to split.
* split\_sizes ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – Sizes of each output tensor. Must sum to the size of the split
dimension `axis`.
* axis ([int](https://docs.python.org/3/library/functions.html#int)) – Dimension to split the input tensor. Must have a statically
known dimension size.
**Returns:**
A list of tensors with the same length as split\_sizes, where each
tensor has the same shape as the input except along the split dimension
axis, where the size is given by the corresponding element in
split\_sizes.
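The slicing rule can be sketched in plain Python for a 1-D input on axis 0 (illustrative only, not the MAX implementation):

```python
def split_1d(x, split_sizes):
    """Sketch of split on axis 0: cut a 1-D sequence into consecutive
    pieces whose lengths are given by split_sizes."""
    if sum(split_sizes) != len(x):
        raise ValueError("split_sizes must sum to the axis dimension")
    out, offset = [], 0
    for size in split_sizes:
        out.append(x[offset:offset + size])
        offset += size
    return out

print(split_1d([10, 20, 30, 40, 50], [2, 3]))  # [[10, 20], [30, 40, 50]]
```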
### `sqrt()` {#max.graph.ops.sqrt}
> max.graph.ops.sqrt(x)
Computes the elementwise square root of a symbolic tensor.
Creates a new op node to compute the elementwise square root of a symbolic
tensor and adds it to the graph, returning the symbolic result. Square root
is commonly used in normalization operations, distance calculations, and
implementing mathematical operations like standard deviation.
```python
import max.functional as F
from max.tensor import Tensor
## Create tensor with positive values
x = Tensor.constant([1.0, 4.0, 9.0, 16.0])
## Compute square root
result = F.sqrt(x)
print(result)
## Output: [1.0, 2.0, 3.0, 4.0]
## Note: sqrt requires non-negative values
## For tensors with negative values, use abs first:
y = Tensor.constant([1.0, -4.0, 9.0, -16.0])
result2 = F.sqrt(F.abs(y))
print(result2)
## Output: [1.0, 2.0, 3.0, 4.0]
```
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The symbolic tensor to use as the input to the sqrt
computation. If it’s not a floating-point DType, an exception will be raised.
**Returns:**
A new symbolic tensor value representing the output of the sqrt
value computation.
**Raises:**
Error – If the symbol doesn’t represent a tensor value.
### `squeeze()` {#max.graph.ops.squeeze}
> max.graph.ops.squeeze(x, axis)
Removes a size-1 dimension from a symbolic tensor.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input symbolic tensor to squeeze.
* axis ([int](https://docs.python.org/3/library/functions.html#int)) – The dimension to remove from the input’s shape. If negative, this
indexes from the end of the tensor. For example,
`squeeze(v, -1)` squeezes the last dimension.
**Returns:**
A symbolic tensor with the same number of elements as the input tensor,
and whose rank is 1 less than the rank of the input tensor.
### `stack()` {#max.graph.ops.stack}
> max.graph.ops.stack(values, axis=0)
Stacks a list of tensors along a new axis.
**Parameters:**
* values ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)]) – A list of symbolic tensor values. Each tensor must have the same
dtype and rank, and must have the same dimension size for each
dimension.
* axis ([int](https://docs.python.org/3/library/functions.html#int)) – The axis to concatenate along. If negative, indexes relative
to the end of the tensor shape plus 1. For instance,
`stack(vs, -1)` will create and stack along a new axis as the
last dimension, and `stack(vs, -2)` will create and stack along a new
dimension which is inserted immediately before the last dimension.
**Returns:**
A new symbolic tensor representing the result of the stack. It will
have rank `n+1` where `n` is the rank of each input tensor. Its size
on each dimension other than `axis` will be the same as each input tensor’s,
with the new axis inserted. Along the new dimension it will have size
`len(values)`.
### `sum()` {#max.graph.ops.sum}
> max.graph.ops.sum(x, axis=-1)
Reduces a symbolic tensor using a sum operation.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input tensor for the operation.
* axis ([int](https://docs.python.org/3/library/functions.html#int)) – The axis along which to compute the reduction. If negative,
indexes from the last dimension. For example, a value of -1 will
compute the reduction along the last dimension.
**Returns:**
A symbolic tensor representing the result of the sum operation.
The tensor will have the same rank as the input tensor, and the same
shape except along the `axis` dimension which will have size 1.
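The rank-preserving (size-1) reduction can be sketched in plain Python for a 2-D input represented as nested lists (illustrative only, not the MAX implementation):

```python
def sum_keepdim(x, axis):
    """Sum a 2-D nested list along axis, keeping a size-1 dimension
    so the result has the same rank as the input."""
    if axis in (-1, 1):
        return [[sum(row)] for row in x]          # reduce along columns
    return [[sum(col) for col in zip(*x)]]        # axis 0: reduce along rows

x = [[1, 2, 3], [4, 5, 6]]
print(sum_keepdim(x, axis=-1))  # [[6], [15]]  - shape (2, 1)
print(sum_keepdim(x, axis=0))   # [[5, 7, 9]]  - shape (1, 3)
```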
### `tanh()` {#max.graph.ops.tanh}
> max.graph.ops.tanh(x)
Computes the elementwise tanh (hyperbolic tangent) of a symbolic tensor.
Creates a new op node to compute the elementwise tanh of a symbolic tensor
and adds it to the graph, returning the symbolic result. Tanh is defined as
`tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`, mapping all input
values to the range (-1, 1).
The tanh function is commonly used as an activation function in recurrent
neural networks (RNNs) and as a hidden layer activation in feedforward networks.
Unlike sigmoid which maps to (0, 1), tanh is zero-centered, which can help
with gradient flow during training.
```python
import max.functional as F
from max.tensor import Tensor
## Create input tensor
x = Tensor.constant([[-2.0, -1.0, 0.0], [1.0, 2.0, 3.0]])
## Apply tanh activation
result = F.tanh(x)
print(result)
## Output: [[-0.964, -0.762, 0.0], [0.762, 0.964, 0.995]]
## All values mapped to range (-1, 1)
```
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The symbolic tensor to use as the input to the tanh
computation. If it’s not a floating-point DType, an exception will be raised.
**Returns:**
A new symbolic tensor value representing the output of the tanh
computation.
**Raises:**
Error – If the symbol doesn’t represent a tensor value.
### `tile()` {#max.graph.ops.tile}
> max.graph.ops.tile(x, repeats)
Returns a new tensor that repeats the input tensor N\_i times
along each dimension i, where N\_i = repeats\[i].
The i-th dimension of the output shape is the i-th dimension of the input
shape multiplied by N\_i.
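The output-shape rule matches NumPy's `np.tile`; a quick illustration (NumPy shown as an analogue, not the MAX API):

```python
import numpy as np

# Each output dimension i is the input dimension times repeats[i].
x = np.array([[1, 2], [3, 4]])
tiled = np.tile(x, (2, 3))
print(tiled.shape)        # (4, 6)
print(tiled[0].tolist())  # [1, 2, 1, 2, 1, 2]
```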
### `top_k()` {#max.graph.ops.top_k}
> max.graph.ops.top\_k(input, k, axis=-1)
Returns a tensor containing only the top k values along the given axis.
**Parameters:**
* input (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input tensor from which to select top k.
* k ([int](https://docs.python.org/3/library/functions.html#int)) – The number of values to select from input.
* axis ([int](https://docs.python.org/3/library/functions.html#int)) – The axis from which to select top k.
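A plain-NumPy reference for the value-selection semantics (a sketch of the behavior described above, not the MAX implementation; the helper name `top_k_ref` is hypothetical):

```python
import numpy as np

def top_k_ref(x, k, axis=-1):
    # Keep the k largest values along `axis`, in descending order:
    # sort the negated array ascending, take the first k, negate back.
    return -np.sort(-x, axis=axis).take(range(k), axis=axis)

x = np.array([[3, 1, 4, 1], [5, 9, 2, 6]])
print(top_k_ref(x, 2).tolist())  # [[4, 3], [9, 6]]
```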
### `transfer_to()` {#max.graph.ops.transfer_to}
> max.graph.ops.transfer\_to(x, device)
Device-to-Device transfer operation.
This op transfers the input tensor from its current device to another. A device represents a
computation unit, such as a CPU or GPU. This op is useful when working with
accelerators, for example to move data from one GPU to another, or from a
GPU to the CPU.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue) – The input tensor to transfer.
* device ([Device](../driver.md#max.driver.Device) | [DeviceRef](type.md#max.graph.type.DeviceRef)) – The device to transfer to.
### `transpose()` {#max.graph.ops.transpose}
> max.graph.ops.transpose(x, axis\_1, axis\_2)
Transposes two axes of a symbolic tensor.
For more information, see [`transpose()`](TensorValue.md#max.graph.TensorValue.transpose).
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input symbolic tensor to transpose.
* axis\_1 ([int](https://docs.python.org/3/library/functions.html#int)) – One of the two axes to transpose. If negative, this indexes
from the end of the tensor. For example,
`transpose(v, -1, -2)` transposes the last two axes.
* axis\_2 ([int](https://docs.python.org/3/library/functions.html#int)) – The other axis to transpose. May also be negative to index from
the end of the tensor.
**Returns:**
A new symbolic tensor with the two specified axes transposed.
It has the same elements and dtype, but the order of the elements
is different according to the transposition.
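The axis-swapping behavior matches NumPy's `np.swapaxes` (shown here only to illustrate the shape rule, not the MAX API):

```python
import numpy as np

# Swapping the last two axes, as in transpose(v, -1, -2).
x = np.arange(24).reshape(2, 3, 4)
print(np.swapaxes(x, -1, -2).shape)  # (2, 4, 3)
```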
### `unsqueeze()` {#max.graph.ops.unsqueeze}
> max.graph.ops.unsqueeze(x, axis)
Inserts a size-1 dimension into a symbolic tensor.
**Parameters:**
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The input symbolic tensor to unsqueeze.
* axis ([int](https://docs.python.org/3/library/functions.html#int)) – The index at which to insert a new dimension into the input’s
shape. Elements at that index or higher are shifted back.
If negative, it indexes relative to the rank of the tensor plus 1.
For example, `unsqueeze(v, -1)` adds a new dimension at the
end, and `unsqueeze(v, -2)` inserts the dimension immediately
before the last dimension.
**Returns:**
A symbolic tensor with the same number of elements as the input tensor,
whose rank is 1 larger than the rank of the input tensor. The result’s
shape at the `axis` dimension is a static dimension of size 1.
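NumPy's `np.expand_dims` follows the same negative-axis rule, which makes the behavior easy to see (a NumPy analogue, not the MAX API):

```python
import numpy as np

v = np.zeros((2, 3))
print(np.expand_dims(v, -1).shape)  # (2, 3, 1): new trailing dimension
print(np.expand_dims(v, -2).shape)  # (2, 1, 3): before the last dimension
print(np.expand_dims(v, 0).shape)   # (1, 2, 3): new leading dimension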
### `where()` {#max.graph.ops.where}
> max.graph.ops.where(condition, x, y)
Returns `condition ? x : y` (element-wise), where `condition`, `x`, and `y`
are input tensors.
**Parameters:**
* condition (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – The condition tensor to use for selecting elementwise
values. This tensor must have a boolean dtype.
* x (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – If the condition is true at a position, the value from the same
position in this tensor will be selected.
* y (Value\[TensorType] | [TensorValue](TensorValue.md#max.graph.TensorValue) | [Shape](shape.md#max.graph.shape.Shape) | [Dim](dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../driver.md#max.driver.DLPackArray)) – If the condition is false at a position, the value from the same
position in this tensor will be selected.
**Returns:**
A new symbolic tensor holding values from either `x` or `y`,
based on the elements in `condition`.
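The elementwise selection rule is the same as NumPy's `np.where` (NumPy used as an illustration, not the MAX API):

```python
import numpy as np

cond = np.array([True, False, True])
x = np.array([1, 2, 3])
y = np.array([10, 20, 30])
print(np.where(cond, x, y).tolist())  # [1, 20, 3]
```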
### `while_loop()` {#max.graph.ops.while_loop}
> max.graph.ops.while\_loop(initial\_values, predicate, body)
Execute a loop until the predicate evaluates to false.
Both the predicate and body functions must take as arguments the same
number and types of values as specified in `initial_values`. The predicate
function must return only a boolean scalar tensor of type `DType.bool`.
The body function must return a list of values matching the types of
`initial_values` (or may return a value directly if there is only one).
The following example demonstrates a basic while loop with a single argument:
```python
from max.graph import DeviceRef, Graph, ops
from max.dtype import DType

with Graph("while_loop_example") as g:
    x = ops.constant(0, dtype=DType.int32, device=DeviceRef.CPU())

    def pred(x):
        return x < 10

    def body(x):
        return x + 1

    result = ops.while_loop(x, pred, body)
    print(result)
```
The following example shows a while loop with multiple arguments:
```python
from max.graph import DeviceRef, Graph, ops
from max.dtype import DType

with Graph("while_loop_example") as g:
    x = ops.constant(0, dtype=DType.int32, device=DeviceRef.CPU())
    y = ops.constant(5, dtype=DType.int32, device=DeviceRef.CPU())

    def pred(x, y):
        return ops.logical_and(x < 10, y < 15)

    def body(x, y):
        return [x + 1, y + 1]

    results = ops.while_loop((x, y), pred, body)
    print(results)
```
**Parameters:**
* initial\_values ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[Value](Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]] | [Value](Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]) – Initial values for loop arguments. Must be non-empty.
* predicate ([Callable](#max.graph.ops.Callable)\[\[...], [TensorValue](TensorValue.md#max.graph.TensorValue)]) – Callable that takes loop arguments and returns a boolean scalar tensor
of type `DType.bool` determining loop continuation.
* body ([Callable](#max.graph.ops.Callable)\[\[...], [Value](Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[Value](Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]]) – Callable that takes loop arguments and returns updated values matching
the types of initial\_values.
**Returns:**
List of output values from the final loop iteration.
**Raises:**
* [ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If initial\_values is empty.
* [NotImplementedError](https://docs.python.org/3/library/exceptions.html#NotImplementedError) – If any init\_arg is a `BufferValue`.
:::note Note
Buffer operations are currently not supported.
:::
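The loop semantics can be sketched in plain Python (a reference model of the behavior described above, not the graph op itself; the helper name `while_loop_reference` is hypothetical):

```python
def while_loop_reference(initial_values, predicate, body):
    # Iterate until the predicate returns False, threading the
    # loop-carried values through each call to body.
    values = list(initial_values)
    while predicate(*values):
        result = body(*values)
        values = list(result) if isinstance(result, (list, tuple)) else [result]
    return values

# Mirrors the single-argument example above: count from 0 up to 10.
print(while_loop_reference([0], lambda x: x < 10, lambda x: x + 1))  # [10]
```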
---
## quantization
APIs to quantize graph tensors.
This package includes a comprehensive set of tools for working with quantized
models in MAX Graph. It defines supported quantization encodings, configuration
parameters that control the quantization process, and block parameter
specifications for different quantization formats.
The module supports various quantization formats including 4-bit, 5-bit, and
6-bit precision with different encoding schemes. It also provides support for
GGUF-compatible formats for interoperability with other frameworks.
## `BlockParameters` {#max.graph.quantization.BlockParameters}
> class max.graph.quantization.BlockParameters(elements\_per\_block, block\_size)
Parameters describing the structure of a quantization block.
Block-based quantization stores elements in fixed-size blocks.
Each block contains a specific number of elements in a compressed format.
### `block_size` {#max.graph.quantization.BlockParameters.block_size}
> block\_size: [int](https://docs.python.org/3/library/functions.html#int)
### `elements_per_block` {#max.graph.quantization.BlockParameters.elements_per_block}
> elements\_per\_block: [int](https://docs.python.org/3/library/functions.html#int)
## `QuantizationConfig` {#max.graph.quantization.QuantizationConfig}
> class max.graph.quantization.QuantizationConfig(quant\_method, bits, group\_size, desc\_act=False, sym=False)
Configuration for specifying quantization parameters that affect inference.
These parameters control how tensor values are quantized, including the method,
bit precision, grouping, and other characteristics that affect the trade-off
between model size, inference speed, and accuracy.
### `bits` {#max.graph.quantization.QuantizationConfig.bits}
> bits: [int](https://docs.python.org/3/library/functions.html#int)
### `desc_act` {#max.graph.quantization.QuantizationConfig.desc_act}
> desc\_act: [bool](https://docs.python.org/3/library/functions.html#bool) = False
### `group_size` {#max.graph.quantization.QuantizationConfig.group_size}
> group\_size: [int](https://docs.python.org/3/library/functions.html#int)
### `quant_method` {#max.graph.quantization.QuantizationConfig.quant_method}
> quant\_method: [str](https://docs.python.org/3/library/stdtypes.html#str)
### `sym` {#max.graph.quantization.QuantizationConfig.sym}
> sym: [bool](https://docs.python.org/3/library/functions.html#bool) = False
## `QuantizationEncoding` {#max.graph.quantization.QuantizationEncoding}
> class max.graph.quantization.QuantizationEncoding(value, names=\<not given>, \*values, module=None, qualname=None, type=None, start=1, boundary=None)
Quantization encodings supported by MAX Graph.
Quantization reduces the precision of neural network weights to decrease
memory usage and potentially improve inference speed. Each encoding represents
a different compression method with specific trade-offs between model size,
accuracy, and computational efficiency.
These encodings are commonly used with pre-quantized model checkpoints
(especially GGUF format) or can be applied during weight allocation.
The following example shows how to create a quantized weight using the Q4\_K encoding:
```python
from max.dtype import DType
from max.graph import DeviceRef, Weight
from max.graph.quantization import QuantizationEncoding

encoding = QuantizationEncoding.Q4_K
quantized_weight = Weight(
    name="linear.weight",
    dtype=DType.uint8,
    shape=[4096, 4096],
    device=DeviceRef.GPU(0),
    quantization_encoding=encoding,
)
```
MAX supports several quantization formats optimized for different use cases.
### `Q4_0` {#max.graph.quantization.QuantizationEncoding.Q4_0}
> Q4\_0
Basic 4-bit quantization with 32 elements per block.
### `Q4_K` {#max.graph.quantization.QuantizationEncoding.Q4_K}
> Q4\_K
4-bit K-quantization with 256 elements per block.
### `Q5_K` {#max.graph.quantization.QuantizationEncoding.Q5_K}
> Q5\_K
5-bit K-quantization with 256 elements per block.
### `Q6_K` {#max.graph.quantization.QuantizationEncoding.Q6_K}
> Q6\_K
6-bit K-quantization with 256 elements per block.
### `GPTQ` {#max.graph.quantization.QuantizationEncoding.GPTQ}
> GPTQ
Group-wise Post-Training Quantization for large language models.
### `block_parameters` {#max.graph.quantization.QuantizationEncoding.block_parameters}
> property block\_parameters: [BlockParameters](#max.graph.quantization.BlockParameters)
Gets the block parameters for this quantization encoding.
**Returns:**
The parameters describing how elements are organized
and encoded in blocks for this quantization encoding.
### `block_size` {#max.graph.quantization.QuantizationEncoding.block_size}
> property block\_size: [int](https://docs.python.org/3/library/functions.html#int)
Number of bytes in encoded representation of block.
All quantization types currently supported by MAX Graph are
block-based: groups of a fixed number of elements are formed, and each
group is quantized together into a fixed-size output block. This value
is the number of bytes resulting after encoding a single block.
### `elements_per_block` {#max.graph.quantization.QuantizationEncoding.elements_per_block}
> property elements\_per\_block: [int](https://docs.python.org/3/library/functions.html#int)
Number of elements per block.
All quantization types currently supported by MAX Graph are
block-based: groups of a fixed number of elements are formed, and each
group is quantized together into a fixed-size output block. This value
is the number of elements gathered into a block.
**Returns:**
Number of original tensor elements in each quantized block.
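Because every supported encoding is block-based, the total encoded size follows directly from these two block parameters. A small arithmetic sketch (the block figures below are hypothetical placeholders for illustration; read the real values from an encoding's `block_parameters` property at runtime):

```python
import math

def quantized_size_bytes(num_elements, elements_per_block, block_size):
    # Whole blocks are formed, so a partial final block still
    # costs a full block of output bytes.
    return math.ceil(num_elements / elements_per_block) * block_size

# Assumed parameters: 256 elements per block, 144 bytes per encoded block.
print(quantized_size_bytes(4096 * 4096, elements_per_block=256, block_size=144))
# 9437184 bytes for a 4096x4096 weight under these assumed parameters
```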
### `is_gguf` {#max.graph.quantization.QuantizationEncoding.is_gguf}
> property is\_gguf: [bool](https://docs.python.org/3/library/functions.html#bool)
Checks if this quantization encoding is compatible with GGUF format.
GGUF is a format for storing large language models and compatible
quantized weights.
**Returns:**
True if this encoding is compatible with GGUF, False otherwise.
### `name` {#max.graph.quantization.QuantizationEncoding.name}
> property name: [str](https://docs.python.org/3/library/stdtypes.html#str)
Gets the lowercase name of the quantization encoding.
**Returns:**
Lowercase string representation of the quantization encoding.
### `parameters` {#max.graph.shape.Shape.parameters}
> property parameters: [Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[SymbolicDim](dim.md#max.graph.dim.SymbolicDim)]
Lists the symbolic dimension names on which this shape depends.
### `rank` {#max.graph.shape.Shape.rank}
> property rank
### `static_dims` {#max.graph.shape.Shape.static_dims}
> property static\_dims: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]
Returns all static dims in the shape as a list of integers.
### `to_mlir()` {#max.graph.shape.Shape.to_mlir}
> to\_mlir()
**Return type:**
ShapeAttr
---
## type
Library for graph value types.
## `BufferType` {#max.graph.type.BufferType}
> class max.graph.type.BufferType(dtype, shape, device)
A symbolic buffer type.
This is a reference to a tensor that can be mutated in place.
### `to_mlir()` {#max.graph.type.ConvInputLayout.to_mlir}
> to\_mlir()
Returns an mlir Attribute representing this Layout.
This attribute is used for certain convolution ops.
**Returns:**
An Attribute representing the layout.
**Return type:**
StringAttr
## `DeviceKind` {#max.graph.type.DeviceKind}
> class max.graph.type.DeviceKind(value, names=\<not given>, \*values, module=None, qualname=None, type=None, start=1, boundary=None)
A device type representation.
### `CPU` {#max.graph.type.DeviceKind.CPU}
> CPU = 'cpu'
### `GPU` {#max.graph.type.DeviceKind.GPU}
> GPU = 'gpu'
### `from_string()` {#max.graph.type.DeviceKind.from_string}
> static from\_string(txt)
## `DeviceRef` {#max.graph.type.DeviceRef}
> class max.graph.type.DeviceRef(device\_type, id=0)
A symbolic device representation.
DeviceRef type representation consists of a DeviceKind and an id. This is a direct
representation of the device attribute in mlir.
The following example demonstrates how to create and use device references:
```python
from max.graph import DeviceRef
gpu_device = DeviceRef.GPU()
print(gpu_device) # Outputs: gpu:0
# Create a CPU device with specific id
cpu_device = DeviceRef.CPU(id=1)
print(cpu_device) # Outputs: cpu:1
```
**Parameters:**
* device\_type ([DeviceKind](#max.graph.type.DeviceKind))
* id ([int](https://docs.python.org/3/library/functions.html#int))
### `CPU()` {#max.graph.type.DeviceRef.CPU}
> static CPU(id=0)
Static Method for creating a CPU device.
**Parameters:**
id ([int](https://docs.python.org/3/library/functions.html#int))
**Return type:**
[DeviceRef](#max.graph.type.DeviceRef)
### `GPU()` {#max.graph.type.DeviceRef.GPU}
> static GPU(id=0)
Static Method for creating a GPU device.
**Parameters:**
id ([int](https://docs.python.org/3/library/functions.html#int))
### `from_mlir()` {#max.graph.type.DeviceRef.from_mlir}
> static from\_mlir(attr)
Returns a device from an mlir attribute
**Parameters:**
attr (DeviceRefAttr)
**Return type:**
[DeviceRef](#max.graph.type.DeviceRef)
### `id` {#max.graph.type.DeviceRef.id}
> id: [int](https://docs.python.org/3/library/functions.html#int)
### `is_cpu()` {#max.graph.type.DeviceRef.is_cpu}
> is\_cpu()
Returns true if the device is a CPU device.
### `from_mlir()` {#max.graph.type.FilterLayout.from_mlir}
> static from\_mlir(attr)
**Parameters:**
attr (LayoutAttr) – The MLIR Attribute object to parse into a layout.
**Returns:**
The FilterLayout represented by the Attribute value.
**Return type:**
[FilterLayout](#max.graph.type.FilterLayout)
### `to_mlir()` {#max.graph.type.FilterLayout.to_mlir}
> to\_mlir()
Returns an mlir Attribute representing this Layout.
This attribute is used in tensor type metadata for certain ops.
**Returns:**
An Attribute representing the layout.
**Return type:**
LayoutAttr
## `TensorType` {#max.graph.type.TensorType}
> class max.graph.type.TensorType(dtype, shape, device, \_layout=None)
A symbolic tensor type.
This is not an eager tensor type! This contains no actual data, but
instead represents the type of a value at some point in time during model
execution.
Most internal values in a model will be tensors. This type represents
their element type (`dtype`) and dimensions (`shape`) at a specific point during
model computation. It allows us to do some optimistic optimizations and
shape inference during graph construction, and to provide more detailed
shape information to the compiler for further optimization passes.
The following example shows how to create a tensor type with static dimensions and access its properties:
```python
from max.graph import TensorType
from max.dtype import DType
# Create a tensor type with float32 elements and static dimensions 2x3
tensor_type = TensorType(DType.float32, (2, 3))
print(tensor_type.dtype) # Outputs: DType.float32
print(tensor_type.shape) # Outputs: [2, 3]
```
It can also represent a fully dynamic-rank tensor. The presence of dynamic-rank
tensors in a graph often degrades performance dramatically and
prevents many classes of optimizations.
An optional device (`device`) can also be provided to indicate the explicit
device the tensor is associated with.
### `as_buffer()` {#max.graph.type.TensorType.as_buffer}
> as\_buffer()
Returns the analogous buffer type.
**Return type:**
[BufferType](#max.graph.type.BufferType)
### `from_mlir()` {#max.graph.type.TensorType.from_mlir}
> classmethod from\_mlir(type)
Constructs a tensor type from an MLIR type.
**Parameters:**
* t – The MLIR Type object to parse into a tensor type.
* type (TensorType)
**Returns:**
The tensor type represented by the MLIR Type value.
**Return type:**
[TensorType](#max.graph.type.TensorType)
### `to_mlir()` {#max.graph.type.TensorType.to_mlir}
> to\_mlir()
Converts to an `mlir.Type` instance.
**Returns:**
An `mlir.Type` in the specified Context.
**Return type:**
TensorType
## `Type` {#max.graph.type.Type}
> class max.graph.type.Type
Represents any possible type for Graph values.
Every Value in the Graph has a Type, and that type is represented by a Type instance.
This type may be inspected to get finer-grained types and learn more
about an individual Value.
The following example shows how to work with types in a graph:
```python
from max.graph import Graph, TensorType
from max.dtype import DType

with Graph() as g:
    # Create a tensor type with a specific dtype and shape
    tensor_type = TensorType(DType.float32, [2, 3])

    # The type can be inspected to get information about the value
    print(f"Tensor element type: {tensor_type.dtype}")  # Outputs: DType.float32
    print(f"Tensor shape: {tensor_type.shape}")  # Outputs: [2, 3]
```
### `from_mlir()` {#max.graph.type.Type.from_mlir}
> static from\_mlir(t)
Constructs a type from an MLIR type.
**Parameters:**
t (MlirType) – The MLIR Type object to parse into a type.
### `to_mlir()` {#max.graph.type.Type.to_mlir}
> to\_mlir()
Converts to an `mlir.Type` instance.
**Returns:**
An `mlir.Type` in the specified Context.
**Return type:**
MlirType
---
## weights
Weights are the learned parameters that store a neural network’s knowledge.
They’re multi-dimensional arrays (tensors) of numerical values that determine how
the model transforms inputs into outputs. These weights contain all the
information needed for a model to perform its task, whether that’s text
generation, image classification, or any other capability.
## `GGUFWeights` {#max.graph.weights.GGUFWeights}
> class max.graph.weights.GGUFWeights(source, tensors=None, prefix='', allocated=None)
Implementation for loading weights from GGUF (GPT-Generated Unified Format) files.
`GGUFWeights` provides an interface to load model weights from GGUF files,
which are optimized for quantized large language models. GGUF is the
successor to GGML format and is commonly used in the `llama.cpp` ecosystem
for efficient storage and loading of quantized models.
```python
from pathlib import Path

from max.dtype import DType
from max.graph import DeviceRef
from max.graph.quantization import QuantizationEncoding
from max.graph.weights import GGUFWeights

gguf_path = Path("model-q4_k.gguf")
weights = GGUFWeights(gguf_path)

# Check if a weight exists
if weights.model.layers[0].attention.wq.exists():
    # Allocate quantized attention weight
    wq_weight = weights.model.layers[0].attention.wq.allocate(
        dtype=DType.uint8,  # GGUF quantized weights use uint8
        device=DeviceRef.CPU(),
    )

# Access weight data with quantization info
weight_data = weights.model.layers[0].attention.wq.data()
print(f"Quantization: {weight_data.quantization_encoding}")
print(f"Shape: {weight_data.shape}")

# Allocate with quantization validation
ffn_weight = weights.model.layers[0].feed_forward.w1.allocate(
    quantization_encoding=QuantizationEncoding.Q4_K,
    device=DeviceRef.GPU(0),
)

# Iterate through all weights in a layer
for name, weight in weights.model.layers[0].items():
    if weight.exists():
        print(f"Found weight: {name}")
```
### `allocate()` {#max.graph.weights.GGUFWeights.allocate}
> allocate(dtype=None, shape=None, quantization\_encoding=None, device=cpu:0)
Creates and optionally validates a new Weight.
### `allocated_weights` {#max.graph.weights.GGUFWeights.allocated_weights}
> property allocated\_weights: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [DLPackArray](../driver.md#max.driver.DLPackArray)]
Gets the values of all weights that were allocated previously.
### `data()` {#max.graph.weights.GGUFWeights.data}
> data()
Get weight data with metadata.
```python
weight_data = weights.model.embeddings.weight.data()
print(f"Shape: {weight_data.shape}")
print(f"Dtype: {weight_data.dtype}")
# Convert to different dtype
fp16_data = weight_data.astype(DType.float16)
```
**Returns:**
A WeightData object containing the tensor data along with
metadata like name, dtype, shape, and quantization encoding.
**Raises:**
[KeyError](https://docs.python.org/3/library/exceptions.html#KeyError) – If no weight exists at the current hierarchical name.
**Return type:**
[WeightData](#max.graph.weights.WeightData)
### `exists()` {#max.graph.weights.GGUFWeights.exists}
> exists()
Check if a weight with this exact name exists.
```python
if weights.model.classifier.weight.exists():
    classifier = weights.model.classifier.weight.allocate(...)
else:
    print("Classifier weight not found")
```
**Returns:**
True if a weight with the current hierarchical name exists
in the loaded weights, False otherwise.
### `items()` {#max.graph.weights.GGUFWeights.items}
> items()
Iterate through all allocable weights that start with the prefix.
### `name` {#max.graph.weights.GGUFWeights.name}
> property name: [str](https://docs.python.org/3/library/stdtypes.html#str)
The current weight name or prefix.
## `SafetensorWeights` {#max.graph.weights.SafetensorWeights}
> class max.graph.weights.SafetensorWeights(filepaths, \*, tensors=None, tensors\_to\_file\_idx=None, prefix='', allocated=None, \_st\_weight\_map=None, \_st\_file\_handles=None)
Implementation for loading weights from safetensors files.
SafetensorWeights provides a secure and efficient way to load model weights
from safetensors format files. Safetensors is designed by Hugging Face for
safe serialization that prevents arbitrary code execution and supports
memory-mapped loading for fast access.
```python
from pathlib import Path

from max.dtype import DType
from max.graph import DeviceRef
from max.graph.weights import SafetensorWeights

# Load weights from safetensors files
weight_files = [Path("model.safetensors")]
weights = SafetensorWeights(weight_files)

# Check if a weight exists
if weights.model.embeddings.weight.exists():
    # Allocate the embedding weight
    embedding_weight = weights.model.embeddings.weight.allocate(
        dtype=DType.float32,
        device=DeviceRef.CPU(),
    )

# Access weights with hierarchical naming
attn_weight = weights.transformer.layers[0].attention.weight.allocate(
    dtype=DType.float16
)
```
### `allocate()` {#max.graph.weights.SafetensorWeights.allocate}
> allocate(dtype=None, shape=None, quantization\_encoding=None, device=cpu:0)
Creates a Weight that can be added to a graph.
### `allocate_as_bytes()` {#max.graph.weights.SafetensorWeights.allocate_as_bytes}
> allocate\_as\_bytes(dtype=None)
Create a Weight that can be added to the graph. The weight has a uint8
representation instead of the original data type. The last dimension of
the shape is scaled by the number of bytes it takes to represent the
original data type. For example, \[512, 256] float32 weights become
\[512, 1024] uint8 weights. Scalar weights are interpreted as
weights with shape \[1].
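The reshaping rule can be illustrated with NumPy's zero-copy byte view (a NumPy analogue of the behavior described above, not the MAX API):

```python
import numpy as np

# float32 is 4 bytes per element, so viewing the data as raw bytes
# multiplies the last dimension by 4: [512, 256] -> [512, 1024].
w = np.zeros((512, 256), dtype=np.float32)
print(w.view(np.uint8).shape)  # (512, 1024)
```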
### `allocated_weights` {#max.graph.weights.SafetensorWeights.allocated_weights}
> property allocated\_weights: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [DLPackArray](../driver.md#max.driver.DLPackArray)]
Gets the values of all weights that were allocated previously.
### `data()` {#max.graph.weights.SafetensorWeights.data}
> data()
Get weight data with metadata.
```python
weight_data = weights.model.embeddings.weight.data()
print(f"Shape: {weight_data.shape}")
print(f"Dtype: {weight_data.dtype}")
# Convert to different dtype
fp16_data = weight_data.astype(DType.float16)
```
**Returns:**
A WeightData object containing the tensor data along with
metadata like name, dtype, shape, and quantization encoding.
**Raises:**
[KeyError](https://docs.python.org/3/library/exceptions.html#KeyError) – If no weight exists at the current hierarchical name.
**Return type:**
[WeightData](#max.graph.weights.WeightData)
### `exists()` {#max.graph.weights.SafetensorWeights.exists}
> exists()
Check if a weight with this exact name exists.
```python
if weights.model.classifier.weight.exists():
    classifier = weights.model.classifier.weight.allocate(...)
else:
    print("Classifier weight not found")
```
**Returns:**
True if a weight with the current hierarchical name exists
in the loaded weights, False otherwise.
### `items()` {#max.graph.weights.SafetensorWeights.items}
> items()
Iterate through all allocatable weights that start with the current prefix.
### `name` {#max.graph.weights.SafetensorWeights.name}
> property name: [str](https://docs.python.org/3/library/stdtypes.html#str)
The current weight name or prefix.
## `WeightData` {#max.graph.weights.WeightData}
> class max.graph.weights.WeightData(data, name, dtype, shape, quantization\_encoding=None)
Container for weight tensor data with metadata.
`WeightData` encapsulates a weight tensor along with its metadata,
providing utilities for type conversion and format compatibility.
It supports the DLPack protocol for efficient tensor sharing between
frameworks.
**Parameters:**
* data ([DLPackArray](../driver.md#max.driver.DLPackArray))
* name ([str](https://docs.python.org/3/library/stdtypes.html#str))
* dtype ([DType](../dtype.md#max.dtype.DType))
* shape ([Shape](shape.md#max.graph.shape.Shape))
* quantization\_encoding ([QuantizationEncoding](quantization.md#max.graph.quantization.QuantizationEncoding) | None)
### `astype()` {#max.graph.weights.WeightData.astype}
> astype(dtype)
Convert the weight data to a different dtype.
This method performs actual data conversion, unlike `view()` which
reinterprets the underlying bytes. Special handling is provided for
bfloat16 conversions using PyTorch when available.
```python
# Convert float32 weights to float16 for reduced memory
weight_data = weights.model.layer.weight.data()
fp16_data = weight_data.astype(DType.float16)
```
**Parameters:**
dtype ([DType](../dtype.md#max.dtype.DType)) – Target data type for conversion.
**Returns:**
A new WeightData instance with the converted data.
**Return type:**
[WeightData](#max.graph.weights.WeightData)
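The difference between converting and reinterpreting can be shown with plain NumPy (illustrative only; `astype()` performs the conversion on `WeightData`, not on raw arrays):
```python
import numpy as np

x = np.array([1.5, -2.0], dtype=np.float32)

# astype converts values: each float32 is rounded to the nearest
# float16, so the numbers are preserved (up to precision).
converted = x.astype(np.float16)

# view reinterprets the same 8 bytes: two float32 values become four
# float16 values whose numeric content is unrelated to the originals.
reinterpreted = x.view(np.float16)
print(converted.tolist())   # [1.5, -2.0]
print(reinterpreted.shape)  # (4,)
```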
### `data` {#max.graph.weights.WeightData.data}
> data: [DLPackArray](../driver.md#max.driver.DLPackArray)
The weight tensor as a DLPack array.
### `dtype` {#max.graph.weights.WeightData.dtype}
> dtype: [DType](../dtype.md#max.dtype.DType)
Data type of the tensor (for example, `DType.float32`, `DType.uint8`).
### `from_numpy()` {#max.graph.weights.WeightData.from_numpy}
> classmethod from\_numpy(arr, name)
Create WeightData from a numpy array.
**Parameters:**
* arr ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – Numpy array containing the weight data.
* name ([str](https://docs.python.org/3/library/stdtypes.html#str)) – Name to assign to this weight.
**Returns:**
A new WeightData instance with dtype and shape inferred
from the numpy array.
**Return type:**
[WeightData](#max.graph.weights.WeightData)
### `name` {#max.graph.weights.WeightData.name}
> name: [str](https://docs.python.org/3/library/stdtypes.html#str)
Hierarchical name of the weight (for example, `"model.layers.0.weight"`).
### `quantization_encoding` {#max.graph.weights.WeightData.quantization_encoding}
> quantization\_encoding: [QuantizationEncoding](quantization.md#max.graph.quantization.QuantizationEncoding) | [None](https://docs.python.org/3/library/constants.html#None) = None
Optional quantization scheme applied to the weight.
### `shape` {#max.graph.weights.WeightData.shape}
> shape: [Shape](shape.md#max.graph.shape.Shape)
Shape of the tensor as a Shape object.
## `Weights` {#max.graph.weights.Weights}
> class max.graph.weights.Weights(\*args, \*\*kwargs)
Protocol for managing and accessing model weights hierarchically.
The Weights protocol provides a convenient interface for loading and organizing
neural network weights. It supports hierarchical naming through attribute and
index access, making it easy to work with complex model architectures.
Weights in MAX are tensors backed by external memory (buffers or memory-mapped
files) that remain separate from the compiled graph.
```python
from max.graph import Graph
from max.dtype import DType

# Create a graph and get its weights interface
graph = Graph("my_model")
weights = graph.weights()

# Allocate weights with hierarchical naming
attn_weight = weights.transformer.layers[0].attention.weight.allocate(
    dtype=DType.float32,
    shape=(768, 768)
)
# Creates weight named "transformer.layers.0.attention.weight"

# Check if a weight exists before allocating
if weights.transformer.layers[0].mlp.weight.exists():
    mlp_weight = weights.transformer.layers[0].mlp.weight.allocate(
        dtype=DType.float16,
        shape=(768, 3072)
    )
```
### `allocate()` {#max.graph.weights.Weights.allocate}
> allocate(dtype=None, shape=None, quantization\_encoding=None, device=cpu:0)
Create a Weight object for this tensor.
```python
# Allocate a weight with specific configuration
weight = weights.model.layers[0].weight.allocate(
    dtype=DType.float16,       # Convert to half precision
    shape=(768, 768),
    device=DeviceRef.GPU(0)    # Place on first GPU
)

# Add to graph
with graph:
    weight_tensor = graph.add_weight(weight)
```
**Parameters:**
* dtype ([DType](../dtype.md#max.dtype.DType) | None) – Data type for the weight. If `None`, uses the original dtype.
* shape ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]] | None) – Shape of the weight tensor. If `None`, uses the original shape.
* quantization\_encoding ([QuantizationEncoding](quantization.md#max.graph.quantization.QuantizationEncoding) | None) – Quantization scheme to apply (for example, `Q4_K`, `Q8_0`).
* device ([DeviceRef](type.md#max.graph.type.DeviceRef)) – Target device for the weight (CPU or GPU).
**Returns:**
A Weight object that can be added to a graph using
`graph.add_weight()`.
**Return type:**
[Weight](Weight.md#max.graph.Weight)
### `allocated_weights` {#max.graph.weights.Weights.allocated_weights}
> property allocated\_weights: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [DLPackArray](../driver.md#max.driver.DLPackArray)]
Get all previously allocated weights. This only includes weights that were
explicitly allocated using the [`allocate()`](#max.graph.weights.Weights.allocate) method, not all available weights.
**Returns:**
A dictionary mapping weight names to their numpy arrays for
all weights that have been allocated through this interface.
### `data()` {#max.graph.weights.Weights.data}
> data()
Get weight data with metadata.
```python
weight_data = weights.model.embeddings.weight.data()
print(f"Shape: {weight_data.shape}")
print(f"Dtype: {weight_data.dtype}")
# Convert to different dtype
fp16_data = weight_data.astype(DType.float16)
```
**Returns:**
A WeightData object containing the tensor data along with
metadata like name, dtype, shape, and quantization encoding.
**Raises:**
[KeyError](https://docs.python.org/3/library/exceptions.html#KeyError) – If no weight exists at the current hierarchical name.
**Return type:**
[WeightData](#max.graph.weights.WeightData)
### `exists()` {#max.graph.weights.Weights.exists}
> exists()
Check if a weight with this exact name exists.
```python
if weights.model.classifier.weight.exists():
    classifier = weights.model.classifier.weight.allocate(...)
else:
    print("Classifier weight not found")
```
**Returns:**
True if a weight with the current hierarchical name exists
in the loaded weights, False otherwise.
### `items()` {#max.graph.weights.Weights.items}
> items()
Iterate through all weights that start with the current prefix.
```python
# Iterate through all weights in a specific layer
for name, weight in weights.transformer.layers[0].items():
    print(f"Found weight: {name}")
```
**Yields:**
Tuples of (name, weight\_accessor) for each weight under the
current prefix. The name is relative to the current prefix.
### `name` {#max.graph.weights.Weights.name}
> property name: [str](https://docs.python.org/3/library/stdtypes.html#str)
Get the current weight name or prefix.
**Returns:**
The hierarchical name built from attribute and index access.
For example, if accessed as `weights.model.layers[0]`,
returns `model.layers.0`.
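The prefix-building behavior can be mimicked with a small mock class (illustrative only; this is not the real `Weights` implementation):
```python
class MockWeights:
    """Toy stand-in that builds hierarchical weight names the way
    the Weights protocol does, via attribute and index access."""

    def __init__(self, prefix: str = "") -> None:
        self._prefix = prefix

    def __getattr__(self, attr: str) -> "MockWeights":
        # weights.model -> prefix "model"
        sep = "." if self._prefix else ""
        return MockWeights(f"{self._prefix}{sep}{attr}")

    def __getitem__(self, idx: int) -> "MockWeights":
        # layers[0] -> prefix "layers.0"
        return MockWeights(f"{self._prefix}.{idx}")

    @property
    def name(self) -> str:
        return self._prefix


weights = MockWeights()
print(weights.model.layers[0].name)  # model.layers.0
```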
## `WeightsFormat` {#max.graph.weights.WeightsFormat}
> class max.graph.weights.WeightsFormat(value, names=None, \*values, module=None, qualname=None, type=None, start=1, boundary=None)
Enumeration of supported weight file formats.
MAX supports multiple weight formats to accommodate different model sources
and use cases.
### `gguf` {#max.graph.weights.WeightsFormat.gguf}
> gguf = 'gguf'
GGUF (GPT-Generated Unified Format) for quantized models.
File extension: `.gguf`
Optimized for quantized large language models, particularly those from the
llama.cpp ecosystem. Supports multiple quantization schemes (`Q4_K`,
`Q5_K`, `Q8_0`, etc.) and includes model metadata in the file.
### `safetensors` {#max.graph.weights.WeightsFormat.safetensors}
> safetensors = 'safetensors'
Safetensors format for secure and efficient tensor storage.
File extension: `.safetensors`
Designed by Hugging Face for safe serialization that prevents
arbitrary code execution. Uses memory-mapped files for fast loading
and supports sharding across multiple files.
## `load_weights()` {#max.graph.weights.load_weights}
> max.graph.weights.load\_weights(paths)
Loads neural network weights from checkpoint files.
Automatically detects checkpoint formats based on file extensions and returns
the appropriate Weights implementation, creating a seamless interface for
loading weights from different formats.
Supported formats:
* Safetensors: .safetensors
* PyTorch: .bin, .pt, .pth
* GGUF: .gguf
The following example shows how to load weights from a Safetensors file:
```python
from pathlib import Path
from max.graph.weights import load_weights

# Load multi-file checkpoints
sharded_paths = [
    Path("model-00001-of-00003.safetensors"),
    Path("model-00002-of-00003.safetensors"),
    Path("model-00003-of-00003.safetensors")
]
weights = load_weights(sharded_paths)

layer_weight = weights.model.layers[23].mlp.gate_proj.weight.allocate(
    dtype=DType.float32,
    shape=[4096, 14336],
    device=DeviceRef.GPU(0)
)
```
**Parameters:**
paths ([list](https://docs.python.org/3/library/stdtypes.html#list)\[Path]) – List of pathlib.Path objects pointing to checkpoint files.
For multi-file checkpoints (e.g., sharded Safetensors), provide
all file paths in the list. For single-file checkpoints, provide
a list with one path.
**Return type:**
[Weights](#max.graph.weights.Weights)
## `weights_format()` {#max.graph.weights.weights_format}
> max.graph.weights.weights\_format(weight\_paths)
Detect the format of weight files based on their extensions.
This function examines the file extensions of all provided paths to
determine the weight format. All files must have the same format;
mixed formats are not supported.
```python
from pathlib import Path
# Detect format for safetensor files
paths = [Path("model-00001.safetensors"), Path("model-00002.safetensors")]
format = weights_format(paths)
print(format) # WeightsFormat.safetensors
```
**Parameters:**
weight\_paths ([list](https://docs.python.org/3/library/stdtypes.html#list)\[Path]) – List of file paths containing model weights. All files
must have the same extension/format.
**Returns:**
The detected WeightsFormat enum value.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If weight\_paths is empty, contains mixed formats, or
has unsupported file extensions.
**Return type:**
[WeightsFormat](#max.graph.weights.WeightsFormat)
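The detection logic amounts to mapping file extensions to formats and rejecting empty or mixed sets. A simplified sketch of that behavior (not the actual implementation; the extension table is an assumption based on the formats listed under `load_weights()`):
```python
from enum import Enum
from pathlib import Path


class WeightsFormat(Enum):
    gguf = "gguf"
    safetensors = "safetensors"
    pytorch = "pytorch"


# Hypothetical extension table for illustration.
_EXTENSIONS = {
    ".gguf": WeightsFormat.gguf,
    ".safetensors": WeightsFormat.safetensors,
    ".bin": WeightsFormat.pytorch,
    ".pt": WeightsFormat.pytorch,
    ".pth": WeightsFormat.pytorch,
}


def detect_format(paths: list[Path]) -> WeightsFormat:
    if not paths:
        raise ValueError("no weight paths provided")
    formats = {_EXTENSIONS.get(p.suffix) for p in paths}
    if None in formats:
        raise ValueError("unsupported file extension")
    if len(formats) > 1:
        raise ValueError("mixed weight formats are not supported")
    return formats.pop()


paths = [Path("model-00001.safetensors"), Path("model-00002.safetensors")]
print(detect_format(paths))  # WeightsFormat.safetensors
```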
---
## max
The MAX Python API reference.
The MAX API provides a high-performance graph compiler and runtime library that
executes AI models with incredible speed on a wide range of hardware.
MAX offers a layered architecture that lets you work at the level of abstraction
that best fits your needs. From deploying production-ready models with a few
lines of code to building custom neural networks from scratch, each layer builds
upon the others so you can move between levels seamlessly as requirements evolve.
For an introduction, see the
[Model developer guide](/max/develop/).
## Packages and modules
* [`diagnostics.gpu`](/max/api/python/diagnostics/gpu): GPU monitoring and performance diagnostics utilities.
* [`driver`](/max/api/python/driver): Low-level device management and tensor operations.
* [`dtype`](/max/api/python/dtype): Unified data type system supporting various numeric formats.
* [`engine`](/max/api/python/engine): Model execution runtime with automatic optimization.
* [`entrypoints`](/max/api/python/entrypoints): Command-line tools and serving infrastructure.
* [`functional`](/max/api/python/functional): Functional tensor operations (relu, softmax, etc.).
* [`graph`](/max/api/python/graph): Computational graph construction with 100+ operations for complete model control.
* [`interfaces`](/max/api/python/interfaces): Universal interfaces for consistent API integration.
* [`kv_cache`](/max/api/python/kv_cache): KV cache management for efficient attention computation.
* [`nn`](/max/api/python/nn): High-level neural network building blocks with automatic graph compilation.
* [`pipelines`](/max/api/python/pipelines): Pre-built, optimized model architectures for immediate deployment.
* [`profiler`](/max/api/python/profiler): Performance profiling and tracing utilities.
* [`random`](/max/api/python/random): Random tensor generation utilities.
* [`tensor`](/max/api/python/tensor): Tensor class with eager execution.
* [`torch`](/max/api/python/torch): PyTorch integration for custom operations and interoperability.
---
## interfaces
Universal interfaces between all aspects of the MAX Inference Stack.
## `AudioGenerationInputs` {#max.interfaces.AudioGenerationInputs}
> class max.interfaces.AudioGenerationInputs(batch)
Input data structure for audio generation pipelines.
This class represents the input data required for audio generation operations
within the pipeline framework. It extends PipelineInputs and provides type-safe
generic support for different audio generation context types.
### `batch` {#max.interfaces.AudioGenerationInputs.batch}
> batch: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[RequestID](#max.interfaces.RequestID), AudioGenerationContextType]
A dictionary mapping RequestID to AudioGenerationContextType instances.
This batch structure allows for processing multiple audio generation
requests simultaneously while maintaining request-specific context
and configuration data.
## `AudioGenerationMetadata` {#max.interfaces.AudioGenerationMetadata}
> class max.interfaces.AudioGenerationMetadata(\*, sample\_rate=None, duration=None, chunk\_id=None, timestamp=None, final\_chunk=None, model\_name=None, request\_id=None, tokens\_generated=None, processing\_time=None, echo=None)
Represents metadata associated with audio generation.
This class will eventually replace the metadata dictionary used throughout
the AudioGenerationOutput object, providing a structured and type-safe
alternative for audio generation metadata.
**Parameters:**
* sample\_rate ([int](https://docs.python.org/3/library/functions.html#int) | None) – The sample rate of the generated audio in Hz.
* duration ([float](https://docs.python.org/3/library/functions.html#float) | None) – The duration of the generated audio in seconds.
* chunk\_id ([int](https://docs.python.org/3/library/functions.html#int) | None) – Identifier for the audio chunk (useful for streaming).
* timestamp ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) – Timestamp when the audio was generated.
* final\_chunk ([bool](https://docs.python.org/3/library/functions.html#bool) | None) – Whether this is the final chunk in a streaming sequence.
* model\_name ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) – Name of the model used for generation.
* request\_id ([RequestID](#max.interfaces.RequestID) | None) – Unique identifier for the generation request.
* tokens\_generated ([int](https://docs.python.org/3/library/functions.html#int) | None) – Number of tokens generated for this audio.
* processing\_time ([float](https://docs.python.org/3/library/functions.html#float) | None) – Time taken to process this audio chunk in seconds.
* echo ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) – Echo of the input prompt or identifier for verification.
### `tokens_generated` {#max.interfaces.AudioGenerationMetadata.tokens_generated}
> tokens\_generated: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None)
## `AudioGenerationOutput` {#max.interfaces.AudioGenerationOutput}
> class max.interfaces.AudioGenerationOutput(final\_status, steps\_executed, audio\_data=\<factory>, buffer\_speech\_tokens=None, metadata=\<factory>)
Represents a response from the audio generation API.
This class encapsulates the result of an audio generation request, including
the final status, generated audio data, and optional buffered speech tokens.
## `AudioGenerationRequest` {#max.interfaces.AudioGenerationRequest}
### `audio_prompt_tokens` {#max.interfaces.AudioGenerationRequest.audio_prompt_tokens}
> audio\_prompt\_tokens: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]
The prompt speech IDs to use for audio generation.
### `audio_prompt_transcription` {#max.interfaces.AudioGenerationRequest.audio_prompt_transcription}
> audio\_prompt\_transcription: [str](https://docs.python.org/3/library/stdtypes.html#str) = ''
The audio prompt transcription to use for audio generation.
### `buffer_speech_tokens` {#max.interfaces.AudioGenerationRequest.buffer_speech_tokens}
> buffer\_speech\_tokens: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]] | [None](https://docs.python.org/3/library/constants.html#None) = None
An optional tensor containing the last N speech tokens
generated by the model from a previous request.
When this field is specified, this tensor is used to buffer the tokens sent
to the audio decoder.
### `input` {#max.interfaces.AudioGenerationRequest.input}
> input: [str](https://docs.python.org/3/library/stdtypes.html#str) | [None](https://docs.python.org/3/library/constants.html#None) = None
The text to generate audio for. The maximum length is 4096 characters.
### `model` {#max.interfaces.AudioGenerationRequest.model}
> model: [str](https://docs.python.org/3/library/stdtypes.html#str)
The name of the model to be used for generating audio chunks. This should match
the available models on the server and determines the behavior and
capabilities of the response generation.
### `prompt` {#max.interfaces.AudioGenerationRequest.prompt}
> prompt: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)] | [str](https://docs.python.org/3/library/stdtypes.html#str) | [None](https://docs.python.org/3/library/constants.html#None) = None
Optionally provide a preprocessed list of token IDs or a prompt string to pass directly into the model as input.
This replaces automatically generating TokenGeneratorRequestMessages from the input, audio prompt tokens,
and audio prompt transcription fields.
### `sampling_params` {#max.interfaces.AudioGenerationRequest.sampling_params}
> sampling\_params: [SamplingParams](#max.interfaces.SamplingParams)
Request sampling configuration options.
### `streaming` {#max.interfaces.AudioGenerationRequest.streaming}
> streaming: [bool](https://docs.python.org/3/library/functions.html#bool) = True
Whether to stream the audio generation.
## `BaseContext` {#max.interfaces.BaseContext}
> class max.interfaces.BaseContext(\*args, \*\*kwargs)
Core interface for request lifecycle management across all of MAX, including serving, scheduling, and pipelines.
This protocol is intended to provide a unified, minimal contract for request state and status handling throughout the MAX stack.
Each pipeline variant (e.g., text generation, embeddings, image generation) is expected to extend this interface by creating
their own modality-specific context classes that implement this protocol and add additional functionality relevant to their
particular use case.
The minimal interface ensures that all context types can be handled uniformly by the scheduling and serving infrastructure,
while allowing pipeline-specific implementations to add their own state management, input validation, and result handling.
### `is_done` {#max.interfaces.BaseContext.is_done}
> property is\_done: [bool](https://docs.python.org/3/library/functions.html#bool)
Whether the request has completed generation.
### `request_id` {#max.interfaces.BaseContext.request_id}
> property request\_id: [RequestID](#max.interfaces.RequestID)
Unique identifier for the request.
### `status` {#max.interfaces.BaseContext.status}
> property status: [GenerationStatus](#max.interfaces.GenerationStatus)
Current generation status of the request.
## `BatchProcessorInputs` {#max.interfaces.BatchProcessorInputs}
> class max.interfaces.BatchProcessorInputs(logits, logit\_offsets, context\_batch)
Arguments for a batch logits processor.
* logits: The model logits, a float32 tensor with shape (N\_batch, vocab\_size).
N\_batch is the number of logits returned by the model for each sequence in the batch.
* logit\_offsets: If the model returns multiple logits, this is a tensor with
shape (batch\_size + 1, 1) that contains the offsets of each sequence in
the batch. Otherwise, this is None.
* context\_batch: The batch of contexts containing the inputs to the model.
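The offsets layout can be illustrated with NumPy (shapes only; this is not the actual processor code):
```python
import numpy as np

vocab_size = 8
# Suppose the model returned 3 logit rows for the first sequence and
# 2 for the second, stacked into one (N_batch, vocab_size) tensor.
logits = np.random.rand(5, vocab_size).astype(np.float32)

# logit_offsets has batch_size + 1 rows; sequence i owns logits rows
# offsets[i] (inclusive) through offsets[i + 1] (exclusive).
logit_offsets = np.array([[0], [3], [5]])

per_sequence = [
    logits[logit_offsets[i, 0]:logit_offsets[i + 1, 0]]
    for i in range(len(logit_offsets) - 1)
]
print([s.shape for s in per_sequence])  # [(3, 8), (2, 8)]
```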
### `context_batch` {#max.interfaces.BatchProcessorInputs.context_batch}
> context\_batch: Sequence\[[TextGenerationContext](#max.interfaces.TextGenerationContext)]
### `logit_offsets` {#max.interfaces.BatchProcessorInputs.logit_offsets}
> logit\_offsets: md.Buffer | [None](https://docs.python.org/3/library/constants.html#None)
### `logits` {#max.interfaces.BatchProcessorInputs.logits}
> logits: md.Buffer
## `BatchType` {#max.interfaces.BatchType}
> class max.interfaces.BatchType(value, names=None, \*values, module=None, qualname=None, type=None, start=1, boundary=None)
Type of batch.
### `CE` {#max.interfaces.BatchType.CE}
> CE = 'CE'
Context encoding batch.
### `TG` {#max.interfaces.BatchType.TG}
> TG = 'TG'
Token generation batch.
## `EmbeddingsContext` {#max.interfaces.EmbeddingsContext}
> class max.interfaces.EmbeddingsContext(\*args, \*\*kwargs)
Protocol defining the interface for embeddings generation contexts.
An `EmbeddingsContext` represents model inputs for embeddings generation pipelines,
managing the state and parameters needed for generating embeddings from input text.
Unlike text generation contexts, this focuses on single-step embedding generation
without iterative token generation concerns.
This protocol includes only the fields necessary for embeddings generation,
excluding text generation specific features like:
* End-of-sequence token handling (eos\_token\_ids)
* Grammar matchers for structured output (matcher)
* JSON schema constraints (json\_schema)
* Log probability tracking (log\_probabilities)
* Token generation iteration state
### `model_name` {#max.interfaces.EmbeddingsContext.model_name}
> property model\_name: [str](https://docs.python.org/3/library/stdtypes.html#str)
The name of the embeddings model to use.
**Returns:**
A string identifying the specific embeddings model for this request.
### `tokens` {#max.interfaces.EmbeddingsContext.tokens}
> property tokens: [TokenBuffer](#max.interfaces.TokenBuffer)
The input tokens to be embedded.
**Returns:**
A NumPy array of token IDs representing the input text to generate
embeddings for.
## `EmbeddingsGenerationInputs` {#max.interfaces.EmbeddingsGenerationInputs}
> class max.interfaces.EmbeddingsGenerationInputs(batches: list\[dict\[RequestID, EmbeddingsContext]])
## `EmbeddingsGenerationOutput` {#max.interfaces.EmbeddingsGenerationOutput}
**Parameters:**
embeddings ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]]) – The generated embeddings as a NumPy array.
### `embeddings` {#max.interfaces.EmbeddingsGenerationOutput.embeddings}
> embeddings: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]]
The generated embeddings as a NumPy array.
### `is_done` {#max.interfaces.EmbeddingsGenerationOutput.is_done}
> property is\_done: [bool](https://docs.python.org/3/library/functions.html#bool)
Indicates whether the embedding generation process is complete.
**Returns:**
Always True, as embedding generation is a single-step operation.
## `GenerationOutput` {#max.interfaces.GenerationOutput}
> class max.interfaces.GenerationOutput(\*, request\_id, final\_status, output)
Output container for image generation pipeline operations.
This class holds a list of generated images in OpenResponses API format,
along with request tracking and status information. It implements the
PipelineOutput protocol by providing the required is\_done property.
Example:
```python
import numpy as np

from max.interfaces.generation import GenerationOutput
from max.interfaces.request import RequestID
from max.interfaces.request.open_responses import OutputImageContent
from max.interfaces.status import GenerationStatus

img_array1 = np.random.rand(512, 512, 3).astype(np.float32)
img_array2 = np.random.rand(512, 512, 3).astype(np.float32)

result = GenerationOutput(
    request_id=RequestID(value="req-123"),
    final_status=GenerationStatus.END_OF_SEQUENCE,
    output=[
        OutputImageContent.from_numpy(img_array1, format="png"),
        OutputImageContent.from_numpy(img_array2, format="jpeg"),
    ]
)

# Or create from URLs
result_from_urls = GenerationOutput(
    request_id=RequestID(value="req-456"),
    final_status=GenerationStatus.END_OF_SEQUENCE,
    output=[
        OutputImageContent(
            type="output_image",
            image_url="https://example.com/image1.png",
            format="png"
        )
    ]
)

# Check if generation is complete
if result.is_done:
    print(f"Generated {len(result.output)} images")
```
### `final_status` {#max.interfaces.GenerationOutput.final_status}
> final\_status: [GenerationStatus](#max.interfaces.GenerationStatus)
The final status of the generation process.
### `is_done` {#max.interfaces.GenerationOutput.is_done}
> property is\_done: [bool](https://docs.python.org/3/library/functions.html#bool)
Indicates whether the pipeline operation has completed.
**Returns:**
True if the generation is done (status is not ACTIVE),
False otherwise.
### `output` {#max.interfaces.GenerationOutput.output}
> output: [list](https://docs.python.org/3/library/stdtypes.html#list)\[OutputImageContent]
List of OutputImageContent objects representing generated images.
### `request_id` {#max.interfaces.GenerationOutput.request_id}
> request\_id: [RequestID](#max.interfaces.RequestID)
The unique identifier for the generation request.
## `GenerationStatus` {#max.interfaces.GenerationStatus}
> class max.interfaces.GenerationStatus(value, names=None, \*values, module=None, qualname=None, type=None, start=1, boundary=None)
Enum representing the status of a generation process in the MAX API.
### `ACTIVE` {#max.interfaces.GenerationStatus.ACTIVE}
> ACTIVE = 'active'
The generation process is ongoing.
### `CANCELLED` {#max.interfaces.GenerationStatus.CANCELLED}
> CANCELLED = 'cancelled'
The generation process has been cancelled by the user.
### `END_OF_SEQUENCE` {#max.interfaces.GenerationStatus.END_OF_SEQUENCE}
> END\_OF\_SEQUENCE = 'end\_of\_sequence'
The generation process has reached the end of the sequence.
### `MAXIMUM_LENGTH` {#max.interfaces.GenerationStatus.MAXIMUM_LENGTH}
> MAXIMUM\_LENGTH = 'maximum\_length'
The generation process has reached the maximum allowed length.
### `is_done` {#max.interfaces.GenerationStatus.is_done}
> property is\_done: [bool](https://docs.python.org/3/library/functions.html#bool)
Returns True if the generation process is complete (not ACTIVE).
**Returns:**
True if the status is not ACTIVE, indicating completion.
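The `is_done` convenience maps directly onto the enum values; a minimal re-creation of the pattern (a sketch, not the MAX source):
```python
from enum import Enum


class GenerationStatus(Enum):
    """Minimal re-creation of the status enum and its is_done helper."""
    ACTIVE = "active"
    CANCELLED = "cancelled"
    END_OF_SEQUENCE = "end_of_sequence"
    MAXIMUM_LENGTH = "maximum_length"

    @property
    def is_done(self) -> bool:
        # Every terminal status, i.e. anything other than ACTIVE.
        return self is not GenerationStatus.ACTIVE


print(GenerationStatus.ACTIVE.is_done)           # False
print(GenerationStatus.END_OF_SEQUENCE.is_done)  # True
```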
## `ImageContentPart` {#max.interfaces.ImageContentPart}
> class max.interfaces.ImageContentPart(\*, type='image')
**Parameters:**
type ([Literal](https://docs.python.org/3/library/typing.html#typing.Literal)\['image'])
### `model_config` {#max.interfaces.ImageContentPart.model_config}
> model\_config: ClassVar\[ConfigDict] = {'frozen': True}
Configuration for the model; should be a dictionary conforming to pydantic's `ConfigDict`.
### `type` {#max.interfaces.ImageContentPart.type}
> type: Literal\['image']
## `ImageMetadata` {#max.interfaces.ImageMetadata}
> class max.interfaces.ImageMetadata(\*, start\_idx, end\_idx, pixel\_values, image\_hash=None)
Metadata about an image in the prompt.
Each image corresponds to a range in the text token array \[start\_idx, end\_idx).
### `end_idx` {#max.interfaces.ImageMetadata.end_idx}
> end\_idx: [int](https://docs.python.org/3/library/functions.html#int)
One past the index of the last image special token for the image
### `image_hash` {#max.interfaces.ImageMetadata.image_hash}
> image\_hash: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) = None
Hash of the image, for use in prefix caching
### `pixel_values` {#max.interfaces.ImageMetadata.pixel_values}
> pixel\_values: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]
Pixel values for the image.
Can be various dtypes depending on the vision model:
* float32: Original precision
* uint16: BFloat16 bits stored as uint16 (workaround for NumPy’s lack of
native bfloat16 support). Reinterpreted as bfloat16 on GPU.
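The uint16 workaround relies on bfloat16 being the upper 16 bits of a float32, so the round trip can be sketched with NumPy (illustrative only):
```python
import numpy as np

values = np.array([1.0, -2.5], dtype=np.float32)

# Truncate each float32 to its top 16 bits: that bit pattern is the
# bfloat16 encoding, storable in a plain uint16 array.
bf16_bits = (values.view(np.uint32) >> 16).astype(np.uint16)

# The consumer reinterprets: shift the bits back into the high half
# of a float32 to recover the (possibly rounded) value.
restored = (bf16_bits.astype(np.uint32) << 16).view(np.float32)
print(restored.tolist())  # [1.0, -2.5]
```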
### `start_idx` {#max.interfaces.ImageMetadata.start_idx}
> start\_idx: [int](https://docs.python.org/3/library/functions.html#int)
Index of the first image special token for the image
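The half-open `[start_idx, end_idx)` convention means an image's token span is an ordinary Python slice. A small sketch with made-up token IDs:
```python
import numpy as np

# Hypothetical token array: positions 2..4 hold the image's special
# tokens (ids here are invented for illustration).
tokens = np.array([101, 7592, 900, 900, 900, 2088, 102])
start_idx, end_idx = 2, 5  # end_idx is one past the last image token

image_span = tokens[start_idx:end_idx]
print(image_span.tolist())  # [900, 900, 900]
print(end_idx - start_idx)  # 3 tokens belong to this image
```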
## `LoRAOperation` {#max.interfaces.LoRAOperation}
> class max.interfaces.LoRAOperation(value, names=None, \*values, module=None, qualname=None, type=None, start=1, boundary=None)
Enum for different LoRA operations.
### `LOAD` {#max.interfaces.LoRAOperation.LOAD}
> LOAD = 'load'
### `UNLOAD` {#max.interfaces.LoRAOperation.UNLOAD}
> UNLOAD = 'unload'
## `LoRARequest` {#max.interfaces.LoRARequest}
> class max.interfaces.LoRARequest(operation, lora\_name, lora\_path=None)
Container for LoRA adapter requests.
## `LoRAResponse` {#max.interfaces.LoRAResponse}
> class max.interfaces.LoRAResponse(status, message)
Response to a LoRA adapter request.
**Parameters:**
* status ([LoRAStatus](#max.interfaces.LoRAStatus))
* message ([str](https://docs.python.org/3/library/stdtypes.html#str) | [list](https://docs.python.org/3/library/stdtypes.html#list)\[[str](https://docs.python.org/3/library/stdtypes.html#str)])
### `message` {#max.interfaces.LoRAResponse.message}
> message: [str](https://docs.python.org/3/library/stdtypes.html#str) | [list](https://docs.python.org/3/library/stdtypes.html#list)\[[str](https://docs.python.org/3/library/stdtypes.html#str)]
### `status` {#max.interfaces.LoRAResponse.status}
> status: [LoRAStatus](#max.interfaces.LoRAStatus)
## `LoRAStatus` {#max.interfaces.LoRAStatus}
> class max.interfaces.LoRAStatus(value, names=&lt;not given&gt;, \*values, module=None, qualname=None, type=None, start=1, boundary=None)
Enum for LoRA operation status.
### `LOAD_ERROR` {#max.interfaces.LoRAStatus.LOAD_ERROR}
> LOAD\_ERROR = 'load\_error'
### `LOAD_INVALID_ADAPTER` {#max.interfaces.LoRAStatus.LOAD_INVALID_ADAPTER}
> LOAD\_INVALID\_ADAPTER = 'load\_invalid\_adapter'
### `LOAD_INVALID_PATH` {#max.interfaces.LoRAStatus.LOAD_INVALID_PATH}
> LOAD\_INVALID\_PATH = 'load\_invalid\_path'
### `LOAD_NAME_EXISTS` {#max.interfaces.LoRAStatus.LOAD_NAME_EXISTS}
> LOAD\_NAME\_EXISTS = 'load\_name\_exists'
### `SUCCESS` {#max.interfaces.LoRAStatus.SUCCESS}
> SUCCESS = 'success'
### `UNLOAD_ERROR` {#max.interfaces.LoRAStatus.UNLOAD_ERROR}
> UNLOAD\_ERROR = 'unload\_error'
### `UNLOAD_NAME_NONEXISTENT` {#max.interfaces.LoRAStatus.UNLOAD_NAME_NONEXISTENT}
> UNLOAD\_NAME\_NONEXISTENT = 'unload\_name\_nonexistent'
### `UNSPECIFIED_ERROR` {#max.interfaces.LoRAStatus.UNSPECIFIED_ERROR}
> UNSPECIFIED\_ERROR = 'unspecified\_error'
## `LoRAType` {#max.interfaces.LoRAType}
> class max.interfaces.LoRAType(value, names=&lt;not given&gt;, \*values, module=None, qualname=None, type=None, start=1, boundary=None)
Enumeration for LoRA Types.
### `A` {#max.interfaces.LoRAType.A}
> A = 'lora\_A'
Represents the LoRA A matrix (high-rank tensor to low-rank tensor).
### `B` {#max.interfaces.LoRAType.B}
> B = 'lora\_B'
Represents the LoRA B matrix (low-rank tensor to high-rank tensor).
### `BIAS` {#max.interfaces.LoRAType.BIAS}
> BIAS = 'lora.bias'
Represents the LoRA bias matrix (added to matrix B).
### `B_KV` {#max.interfaces.LoRAType.B_KV}
> B\_KV = 'lora\_B\_kv'
Represents the combined K and V LoRA B matrices for QKV fusion.
## `LogProbabilities` {#max.interfaces.LogProbabilities}
> class max.interfaces.LogProbabilities(token\_log\_probabilities, top\_log\_probabilities)
Log probabilities for an individual output token.
This is a data-only class that serves as a serializable data structure for
transferring log probability information. It does not provide any functionality
for calculating or manipulating log probabilities - it is purely for data storage
and serialization purposes.
### `token_log_probabilities` {#max.interfaces.LogProbabilities.token_log_probabilities}
> token\_log\_probabilities: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[float](https://docs.python.org/3/library/functions.html#float)]
Log probabilities of each output token.
### `top_log_probabilities` {#max.interfaces.LogProbabilities.top_log_probabilities}
> top\_log\_probabilities: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[int](https://docs.python.org/3/library/functions.html#int), [float](https://docs.python.org/3/library/functions.html#float)]]
Top candidate tokens and their corresponding log probabilities.
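The layout described above can be sketched with plain Python data (illustrative values, not output of the MAX class):

```python
import math

# Plain-data sketch of LogProbabilities' layout: one entry per generated
# token, plus a parallel list of per-step maps from candidate token ID
# to log probability.
token_log_probabilities = [math.log(0.5), math.log(0.25)]
top_log_probabilities = [
    {11: math.log(0.5), 7: math.log(0.3)},    # step 0 candidates
    {42: math.log(0.25), 11: math.log(0.2)},  # step 1 candidates
]

# The lists are parallel, and the sampled token's log prob matches its
# entry in that step's top candidates.
assert len(token_log_probabilities) == len(top_log_probabilities)
assert math.isclose(top_log_probabilities[0][11], token_log_probabilities[0])
```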
## `MAXPullQueue` {#max.interfaces.MAXPullQueue}
> class max.interfaces.MAXPullQueue(\*args, \*\*kwargs)
Protocol for a minimal, non-blocking pull queue interface in MAX.
This protocol defines the contract for a queue that supports non-blocking
get operations for retrieving items. It is generic over the item type and designed
for scenarios where the caller must be immediately notified if no items are available
rather than waiting for items to arrive.
The protocol is intended for consumer-side queue operations where immediate
feedback about queue state is critical for proper flow control and error handling.
### `get_nowait()` {#max.interfaces.MAXPullQueue.get_nowait}
> get\_nowait()
Remove and return an item from the queue without blocking.
This method is expected to raise queue.Empty if no item is available
to retrieve from the queue.
**Returns:**
The item removed from the queue.
**Return type:**
PullItemType
**Raises:**
[queue.Empty](https://docs.python.org/3/library/queue.html#queue.Empty) – If the queue is empty and no item can be retrieved.
## `MAXPushQueue` {#max.interfaces.MAXPushQueue}
> class max.interfaces.MAXPushQueue(\*args, \*\*kwargs)
Protocol for a minimal, non-blocking push queue interface in MAX.
This protocol defines the contract for a queue that supports non-blocking
put operations for adding items. It is generic over the item type and designed
for scenarios where the caller must be immediately notified of success or failure
rather than waiting for space to become available.
The protocol is intended for producer-side queue operations where immediate
feedback is critical for proper flow control and error handling.
### `put_nowait()` {#max.interfaces.MAXPushQueue.put_nowait}
> put\_nowait(item)
Attempt to put an item into the queue without blocking.
This method is designed to immediately fail (typically by raising an exception)
if the item cannot be added to the queue at the time of the call. Unlike the
traditional ‘put’ method in many queue implementations—which may block until
space becomes available or the transfer is completed—this method never waits.
It is intended for use cases where the caller must be notified of failure to
enqueue immediately, rather than waiting for space.
**Parameters:**
item (PushItemType) – The item to be added to the queue.
**Return type:**
None
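Python's standard `queue.Queue` already exposes `put_nowait()`/`get_nowait()`, so it structurally satisfies both protocols; a short sketch of the non-blocking contract (MAX may supply its own implementations):

```python
import queue

# queue.Queue structurally satisfies MAXPushQueue and MAXPullQueue:
# both operations fail immediately instead of waiting.
q: "queue.Queue[str]" = queue.Queue(maxsize=2)

q.put_nowait("a")
q.put_nowait("b")
try:
    q.put_nowait("c")      # queue is full: fail immediately, never block
    overflowed = False
except queue.Full:
    overflowed = True

drained = []
while True:
    try:
        drained.append(q.get_nowait())
    except queue.Empty:    # immediate feedback instead of waiting
        break

assert overflowed
assert drained == ["a", "b"]
```

The immediate `queue.Full`/`queue.Empty` feedback is what enables the flow control both protocol docstrings describe.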
## `OpenResponsesRequest` {#max.interfaces.OpenResponsesRequest}
> class max.interfaces.OpenResponsesRequest(request\_id, body)
General request container for OpenResponses API requests.
This class wraps an OpenResponsesRequestBody and adheres to the Request schema.
All request fields are accessed directly from the body.
**Parameters:**
* request\_id ([RequestID](#max.interfaces.RequestID))
* body (OpenResponsesRequestBody)
### `body` {#max.interfaces.OpenResponsesRequest.body}
> body: OpenResponsesRequestBody
The complete OpenResponses request body.
### `from_fastapi_request()` {#max.interfaces.OpenResponsesRequest.from_fastapi_request}
> async classmethod from\_fastapi\_request(request)
Create an OpenResponsesRequest from a FastAPI/Starlette Request.
Extracts the request\_id from request.state.request\_id and parses the
request body as an OpenResponsesRequestBody.
**Parameters:**
request (FastAPIRequestProtocol) – A request object with state.request\_id and body() method.
Compatible with FastAPI/Starlette Request objects.
**Returns:**
An OpenResponsesRequest instance.
**Raises:**
* [ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If request.state.request\_id is not set.
* pydantic.ValidationError – If the request body is invalid.
## `Pipeline` {#max.interfaces.Pipeline}
> class max.interfaces.Pipeline
Abstract base class for pipeline operations.
This generic abstract class defines the interface for pipeline operations that
transform inputs of type PipelineInputsType into outputs of type PipelineOutputsDict\[PipelineOutputType].
All concrete pipeline implementations must inherit from this class and implement
the execute method.
Type Parameters:
* PipelineInputsType: The type of inputs this pipeline accepts; must inherit from PipelineInputs.
* PipelineOutputType: The type of outputs this pipeline produces; must be a subclass of PipelineOutput.
```python
class MyPipeline(Pipeline[MyInputs, MyOutput]):
    def execute(self, inputs: MyInputs) -> dict[RequestID, MyOutput]:
        # Implementation here
        pass
```
### `execute()` {#max.interfaces.Pipeline.execute}
> abstract execute(inputs)
Execute the pipeline operation with the given inputs.
This method must be implemented by all concrete pipeline classes.
It takes inputs of the specified type and returns outputs according
to the pipeline’s processing logic.
**Parameters:**
inputs (PipelineInputsType) – The input data for the pipeline operation, must be of type PipelineInputsType
**Returns:**
The results of the pipeline operation, as a dictionary mapping RequestID to PipelineOutputType
**Raises:**
[NotImplementedError](https://docs.python.org/3/library/exceptions.html#NotImplementedError) – If not implemented by a concrete subclass
### `release()` {#max.interfaces.Pipeline.release}
> abstract release(request\_id)
Release any resources or state associated with a specific request.
This method should be implemented by concrete pipeline classes to perform
cleanup or resource deallocation for the given request ID. It is typically
called when a request has completed processing and its associated resources
(such as memory, cache, or temporary files) are no longer needed.
**Parameters:**
request\_id ([RequestID](#max.interfaces.RequestID)) – The unique identifier of the request to release resources for.
**Returns:**
None
**Raises:**
[NotImplementedError](https://docs.python.org/3/library/exceptions.html#NotImplementedError) – If not implemented by a concrete subclass.
**Return type:**
None
## `PipelineInputs` {#max.interfaces.PipelineInputs}
> class max.interfaces.PipelineInputs
Base class representing inputs to a pipeline operation.
This class serves as a marker interface for all pipeline input types.
Concrete implementations should inherit from this class and define
the specific input data structures required for their pipeline operations.
```python
class MyPipelineInputs(PipelineInputs):
    def __init__(self, data: str, config: dict):
        self.data = data
        self.config = config
```
## `PipelineOutput` {#max.interfaces.PipelineOutput}
> class max.interfaces.PipelineOutput(\*args, \*\*kwargs)
Protocol representing the output of a pipeline operation.
Subclasses must implement the is\_done property to indicate whether
the pipeline operation has completed.
### `is_done` {#max.interfaces.PipelineOutput.is_done}
> property is\_done: [bool](https://docs.python.org/3/library/functions.html#bool)
Indicates whether the pipeline operation has completed.
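The `Pipeline`/`PipelineInputs`/`PipelineOutput` contract can be sketched end to end with plain-Python stand-ins (hypothetical `Echo*` classes, not the MAX types):

```python
from dataclasses import dataclass

@dataclass
class EchoInputs:                 # plays the role of PipelineInputs
    batch: dict[str, str]         # request_id -> prompt

@dataclass
class EchoOutput:                 # plays the role of PipelineOutput
    text: str

    @property
    def is_done(self) -> bool:    # this toy pipeline finishes in one step
        return True

class EchoPipeline:               # plays the role of Pipeline
    def execute(self, inputs: EchoInputs) -> dict[str, EchoOutput]:
        # One output per request ID, mirroring the Pipeline.execute contract.
        return {rid: EchoOutput(text=p.upper()) for rid, p in inputs.batch.items()}

    def release(self, request_id: str) -> None:
        pass                      # nothing to clean up in this sketch

out = EchoPipeline().execute(EchoInputs(batch={"r1": "hi"}))
assert out["r1"].text == "HI" and out["r1"].is_done
```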
## `PipelineTokenizer` {#max.interfaces.PipelineTokenizer}
> class max.interfaces.PipelineTokenizer(\*args, \*\*kwargs)
Protocol for tokenizers used by pipelines.
### `encode()` {#max.interfaces.PipelineTokenizer.encode}
> async encode(prompt)
Encodes a text prompt as tokens.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If the prompt exceeds the configured maximum length.
**Return type:**
TokenizerEncoded
### `eos` {#max.interfaces.PipelineTokenizer.eos}
> property eos: [int](https://docs.python.org/3/library/functions.html#int)
The end of sequence token for this tokenizer.
### `expects_content_wrapping` {#max.interfaces.PipelineTokenizer.expects_content_wrapping}
> property expects\_content\_wrapping: [bool](https://docs.python.org/3/library/functions.html#bool)
If true, this tokenizer expects messages to be wrapped as a dict.
Text messages are formatted as:
```json
{
"role": "user",
"content": [{ "type": "text", "text": "text content" }]
}
```
instead of:
```json
{ "role": "user", "content": "text_content" }
```
NOTE: Multimodal messages omit the content property.
Both `image_urls` and `image` content parts are converted to:
```json
{ "type": "image" }
```
Their content is provided as byte arrays through the top-level property
on the request object, i.e., `RequestType.images`.
### `new_context()` {#max.interfaces.PipelineTokenizer.new_context}
> async new\_context(request)
Creates a new context from a request object. This is sent to the
worker process once and then cached locally.
**Parameters:**
request (RequestType) – Incoming request.
**Returns:**
Initialized context.
**Return type:**
UnboundContextType
## `PipelinesFactory` {#max.interfaces.PipelinesFactory}
> max.interfaces.PipelinesFactory
Type alias for factory functions that create pipeline instances.
Factory functions should return a Pipeline with properly typed inputs and outputs
that are bound to the PipelineInputs and PipelineOutput base classes respectively.
This ensures type safety while maintaining flexibility for different pipeline implementations.
**Example:**
```python
def create_text_pipeline() -> Pipeline[TextGenerationInputs, TextGenerationOutput]:
    return MyTextGenerationPipeline()

factory: PipelinesFactory = create_text_pipeline
```
alias of [`Callable`](graph/ops.md#max.graph.ops.Callable)\[\[], [`Pipeline`](#max.interfaces.Pipeline)\[`PipelineInputsType`, `PipelineOutputType`]]
## `PixelGenerationContext` {#max.interfaces.PixelGenerationContext}
> class max.interfaces.PixelGenerationContext(\*args, \*\*kwargs)
Protocol defining the interface for pixel generation contexts.
A `PixelGenerationContext` represents model inputs for pixel generation pipelines,
managing the state and parameters needed for generating images or videos.
### `guidance_scale` {#max.interfaces.PixelGenerationContext.guidance_scale}
> property guidance\_scale: [float](https://docs.python.org/3/library/functions.html#float)
Classifier-free guidance scale (1.0 to disable CFG).
### `height` {#max.interfaces.PixelGenerationContext.height}
> property height: [int](https://docs.python.org/3/library/functions.html#int)
Height of generated output in pixels.
### `latents` {#max.interfaces.PixelGenerationContext.latents}
> property latents: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]]
The latents for the context.
### `num_images_per_prompt` {#max.interfaces.PixelGenerationContext.num_images_per_prompt}
> property num\_images\_per\_prompt: [int](https://docs.python.org/3/library/functions.html#int)
Number of images to generate.
### `num_inference_steps` {#max.interfaces.PixelGenerationContext.num_inference_steps}
> property num\_inference\_steps: [int](https://docs.python.org/3/library/functions.html#int)
Number of denoising steps.
### `tokens` {#max.interfaces.PixelGenerationContext.tokens}
> property tokens: [TokenBuffer](#max.interfaces.TokenBuffer)
The token buffer for the context.
### `width` {#max.interfaces.PixelGenerationContext.width}
> property width: [int](https://docs.python.org/3/library/functions.html#int)
Width of generated output in pixels.
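As a point of reference, `guidance_scale` is typically applied by blending unconditional and conditional predictions at each denoising step; a minimal sketch of that math (illustrative, not MAX's implementation):

```python
# Classifier-free guidance (CFG) blends two model predictions per step.
# A scale of 1.0 reduces to the conditional prediction, i.e. CFG disabled,
# matching the guidance_scale docstring above.
def apply_cfg(uncond: float, cond: float, guidance_scale: float) -> float:
    return uncond + guidance_scale * (cond - uncond)

assert apply_cfg(0.2, 0.8, 1.0) == 0.8   # scale 1.0: pure conditional
assert apply_cfg(0.2, 0.8, 7.5) > 0.8    # scale > 1 pushes past conditional
```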
## `PixelGenerationInputs` {#max.interfaces.PixelGenerationInputs}
> class max.interfaces.PixelGenerationInputs(batch)
Input data structure for pixel generation pipelines.
This class represents the input data required for pixel generation operations
within the pipeline framework. It extends PipelineInputs and provides type-safe
generic support for different pixel generation context types.
### `batch` {#max.interfaces.PixelGenerationInputs.batch}
> batch: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[RequestID](#max.interfaces.RequestID), PixelGenerationContextType]
A dictionary mapping RequestID to PixelGenerationContextType instances.
This batch structure allows for processing multiple pixel generation
requests simultaneously while maintaining request-specific context
and configuration data.
## `PixelGenerationOutput` {#max.interfaces.PixelGenerationOutput}
> class max.interfaces.PixelGenerationOutput(request\_id, final\_status, pixel\_data=&lt;factory&gt;)
Represents a response from the pixel generation API.
This class encapsulates the result of a pixel generation request, including
the request ID, final status, and generated pixel data.
### `final_status` {#max.interfaces.PixelGenerationOutput.final_status}
> final\_status: [GenerationStatus](#max.interfaces.GenerationStatus)
The final status of the generation process.
### `is_done` {#max.interfaces.PixelGenerationOutput.is_done}
> property is\_done: [bool](https://docs.python.org/3/library/functions.html#bool)
Indicates whether the pixel generation process is complete.
## `ProcessorInputs` {#max.interfaces.ProcessorInputs}
> class max.interfaces.ProcessorInputs(context, logits)
Inputs provided to a logits processor.
### `context` {#max.interfaces.ProcessorInputs.context}
> context: [TextGenerationContext](#max.interfaces.TextGenerationContext)
### `logits` {#max.interfaces.ProcessorInputs.logits}
> logits: md.Buffer
## `Request` {#max.interfaces.Request}
> class max.interfaces.Request(request\_id)
Base class representing a generic request within the MAX API.
This class provides a unique identifier for each request, ensuring that
all requests can be tracked and referenced consistently throughout the
system. Subclasses can extend this class to include additional fields
specific to their request types.
### `request_id` {#max.interfaces.Request.request_id}
> request\_id: [RequestID](#max.interfaces.RequestID)
## `RequestID` {#max.interfaces.RequestID}
> class max.interfaces.RequestID(value=&lt;factory&gt;)
A unique immutable identifier for a request.
When instantiated without arguments, automatically generates a new
UUID4-based ID.
**Parameters:**
value ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The string identifier. If not provided, generates a UUID4 hex string.
### `value` {#max.interfaces.RequestID.value}
> value: [str](https://docs.python.org/3/library/stdtypes.html#str)
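The auto-generation behavior can be sketched with a frozen dataclass stand-in (hypothetical `MiniRequestID`, not the MAX class):

```python
import dataclasses
import uuid

# Minimal stand-in for RequestID: a frozen (immutable) wrapper that
# auto-generates a UUID4 hex string when no value is provided.
@dataclasses.dataclass(frozen=True)
class MiniRequestID:
    value: str = dataclasses.field(default_factory=lambda: uuid.uuid4().hex)

a, b = MiniRequestID(), MiniRequestID()
assert a.value != b.value                       # generated IDs are unique
assert MiniRequestID("req-1").value == "req-1"  # explicit values pass through
```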
## `SamplingParams` {#max.interfaces.SamplingParams}
> class max.interfaces.SamplingParams(top\_k=-1, top\_p=1, min\_p=0.0, temperature=1, frequency\_penalty=0.0, presence\_penalty=0.0, repetition\_penalty=1.0, max\_new\_tokens=None, min\_new\_tokens=0, ignore\_eos=False, stop=None, stop\_token\_ids=None, detokenize=True, seed=&lt;factory&gt;, logits\_processors=None)
Request-specific sampling parameters that are only known at run time.
### `detokenize` {#max.interfaces.SamplingParams.detokenize}
> detokenize: [bool](https://docs.python.org/3/library/functions.html#bool) = True
Whether to detokenize the output tokens into text.
### `frequency_penalty` {#max.interfaces.SamplingParams.frequency_penalty}
> frequency\_penalty: [float](https://docs.python.org/3/library/functions.html#float) = 0.0
The frequency penalty to apply to the model’s output. A positive value will penalize new tokens
based on their frequency in the generated text: tokens will receive a penalty proportional to the
count of appearances.
### `from_input_and_generation_config()` {#max.interfaces.SamplingParams.from_input_and_generation_config}
> classmethod from\_input\_and\_generation\_config(input\_params, sampling\_params\_defaults)
Create SamplingParams with defaults from HuggingFace’s GenerationConfig.
This method creates a SamplingParams instance by combining three sources of values,
in priority order (highest to lowest):
1. User-provided values in input\_params (non-None)
2. Model’s GenerationConfig values (only if explicitly set in the model’s config)
3. SamplingParams class defaults
**Parameters:**
* input\_params ([SamplingParamsInput](#max.interfaces.SamplingParamsInput)) – Dataclass containing user-specified parameter values.
Values of None will be replaced with model defaults or class defaults.
* sampling\_params\_defaults ([SamplingParamsGenerationConfigDefaults](#max.interfaces.SamplingParamsGenerationConfigDefaults)) – SamplingParamsGenerationConfigDefaults containing
default sampling parameters extracted from the model’s GenerationConfig.
**Returns:**
A new SamplingParams instance with model-aware defaults.
**Return type:**
[SamplingParams](#max.interfaces.SamplingParams)
**Example:**
```pycon
>>> sampling_defaults = model_config.sampling_params_defaults
>>> params = SamplingParams.from_input_and_generation_config(
... SamplingParamsInput(temperature=0.7), # User override
... sampling_params_defaults=sampling_defaults
... )
```
### `ignore_eos` {#max.interfaces.SamplingParams.ignore_eos}
> ignore\_eos: [bool](https://docs.python.org/3/library/functions.html#bool) = False
If True, the response will ignore the EOS token, and continue to
generate until the max tokens or a stop string is hit.
### `log_sampling_info()` {#max.interfaces.SamplingParams.log_sampling_info}
> log\_sampling\_info()
Log comprehensive sampling parameters information.
Displays all sampling parameters in a consistent visual format similar to
pipeline configuration logging.
**Return type:**
None
### `logits_processors` {#max.interfaces.SamplingParams.logits_processors}
> logits\_processors: [Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[Callable](graph/ops.md#max.graph.ops.Callable)\[\[[ProcessorInputs](#max.interfaces.ProcessorInputs)], [None](https://docs.python.org/3/library/constants.html#None)]] | [None](https://docs.python.org/3/library/constants.html#None) = None
Callables to post-process the model logits.
See `LogitsProcessor` for examples.
### `max_new_tokens` {#max.interfaces.SamplingParams.max_new_tokens}
> max\_new\_tokens: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) = None
The maximum number of new tokens to generate in the response.
When set to an integer value, generation will stop after this many tokens.
When None (default), the model may generate tokens until it reaches its
internal limits or other stopping criteria are met.
### `min_new_tokens` {#max.interfaces.SamplingParams.min_new_tokens}
> min\_new\_tokens: [int](https://docs.python.org/3/library/functions.html#int) = 0
The minimum number of tokens to generate in the response.
### `min_p` {#max.interfaces.SamplingParams.min_p}
> min\_p: [float](https://docs.python.org/3/library/functions.html#float) = 0.0
Float that represents the minimum probability for a token to be considered, relative to the probability of the most likely token. Must be in \[0, 1]. Set to 0 to disable this.
### `presence_penalty` {#max.interfaces.SamplingParams.presence_penalty}
> presence\_penalty: [float](https://docs.python.org/3/library/functions.html#float) = 0.0
The presence penalty to apply to the model’s output. A positive value will penalize new tokens
that have already appeared in the generated text at least once by applying a constant penalty.
### `repetition_penalty` {#max.interfaces.SamplingParams.repetition_penalty}
> repetition\_penalty: [float](https://docs.python.org/3/library/functions.html#float) = 1.0
The repetition penalty to apply to the model’s output. Values > 1 will penalize new tokens
that have already appeared in the generated text at least once by dividing the logits by the
repetition penalty.
### `seed` {#max.interfaces.SamplingParams.seed}
> seed: [int](https://docs.python.org/3/library/functions.html#int)
The seed to use for the random number generator. Defaults to a cryptographically secure random value.
### `stop` {#max.interfaces.SamplingParams.stop}
> stop: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[str](https://docs.python.org/3/library/stdtypes.html#str)] | [None](https://docs.python.org/3/library/constants.html#None) = None
A list of detokenized sequences that can be used as stop criteria when generating a new sequence.
### `stop_token_ids` {#max.interfaces.SamplingParams.stop_token_ids}
> stop\_token\_ids: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)] | [None](https://docs.python.org/3/library/constants.html#None) = None
A list of token ids that are used as stopping criteria when generating a new sequence.
### `temperature` {#max.interfaces.SamplingParams.temperature}
> temperature: [float](https://docs.python.org/3/library/functions.html#float) = 1
Controls the randomness of the model’s output; higher values produce more diverse responses.
For greedy sampling, set the temperature to 0.
### `top_k` {#max.interfaces.SamplingParams.top_k}
> top\_k: [int](https://docs.python.org/3/library/functions.html#int) = -1
Limits the sampling to the K most probable tokens. Defaults to -1 (sample over all tokens); for greedy sampling, set to 1.
### `top_p` {#max.interfaces.SamplingParams.top_p}
> top\_p: [float](https://docs.python.org/3/library/functions.html#float) = 1
Only use the tokens whose cumulative probability is within the top\_p threshold. This applies to the top\_k tokens.
## `SamplingParamsGenerationConfigDefaults` {#max.interfaces.SamplingParamsGenerationConfigDefaults}
> class max.interfaces.SamplingParamsGenerationConfigDefaults(temperature=None, top\_p=None, top\_k=None, repetition\_penalty=None, max\_new\_tokens=None, min\_new\_tokens=None, do\_sample=None)
Default sampling parameter values extracted from a model’s GenerationConfig.
This class encapsulates sampling parameter defaults that come from a HuggingFace
model’s GenerationConfig. These defaults have middle priority when creating
SamplingParams instances:
Priority order (highest to lowest):
1. User-provided values (SamplingParamsInput)
2. Model’s GenerationConfig values (this class)
3. SamplingParams class defaults
All fields default to None, indicating that the model’s GenerationConfig does not
explicitly set that parameter. When None, SamplingParams will fall back to its
own class defaults.
**Example:**
```pycon
>>> # Extract from model config
>>> gen_config = model_config.generation_config
>>> defaults = SamplingParamsGenerationConfigDefaults(
... temperature=0.7,
... top_k=50,
... max_new_tokens=512
... )
>>> # Use with SamplingParams
>>> params = SamplingParams.from_input_and_generation_config(
... SamplingParamsInput(),
... sampling_params_defaults=defaults
... )
```
### `do_sample` {#max.interfaces.SamplingParamsGenerationConfigDefaults.do_sample}
> do\_sample: [bool](https://docs.python.org/3/library/functions.html#bool) | [None](https://docs.python.org/3/library/constants.html#None) = None
If False, use greedy sampling.
### `max_new_tokens` {#max.interfaces.SamplingParamsGenerationConfigDefaults.max_new_tokens}
> max\_new\_tokens: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) = None
Maximum number of new tokens from the model’s GenerationConfig, if explicitly set.
### `min_new_tokens` {#max.interfaces.SamplingParamsGenerationConfigDefaults.min_new_tokens}
> min\_new\_tokens: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) = None
Minimum number of new tokens from the model’s GenerationConfig, if explicitly set.
### `repetition_penalty` {#max.interfaces.SamplingParamsGenerationConfigDefaults.repetition_penalty}
> repetition\_penalty: [float](https://docs.python.org/3/library/functions.html#float) | [None](https://docs.python.org/3/library/constants.html#None) = None
Repetition penalty value from the model’s GenerationConfig, if explicitly set.
### `temperature` {#max.interfaces.SamplingParamsGenerationConfigDefaults.temperature}
> temperature: [float](https://docs.python.org/3/library/functions.html#float) | [None](https://docs.python.org/3/library/constants.html#None) = None
Temperature value from the model’s GenerationConfig, if explicitly set.
### `top_k` {#max.interfaces.SamplingParamsGenerationConfigDefaults.top_k}
> top\_k: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) = None
Top-k sampling value from the model’s GenerationConfig, if explicitly set.
### `top_p` {#max.interfaces.SamplingParamsGenerationConfigDefaults.top_p}
> top\_p: [float](https://docs.python.org/3/library/functions.html#float) | [None](https://docs.python.org/3/library/constants.html#None) = None
Top-p (nucleus sampling) value from the model’s GenerationConfig, if explicitly set.
### `values_to_update` {#max.interfaces.SamplingParamsGenerationConfigDefaults.values_to_update}
> property values\_to\_update: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [float](https://docs.python.org/3/library/functions.html#float) | [int](https://docs.python.org/3/library/functions.html#int)]
Extract non-None field values as a dictionary.
**Returns:**
A dictionary mapping field names to their values, excluding any fields
that are None. This dictionary can be used to update SamplingParams
default values.
**Example:**
```pycon
>>> defaults = SamplingParamsGenerationConfigDefaults(
... temperature=0.7,
... top_k=50
... )
>>> defaults.values_to_update
{'temperature': 0.7, 'top_k': 50}
```
## `SamplingParamsInput` {#max.interfaces.SamplingParamsInput}
> class max.interfaces.SamplingParamsInput(top\_k=None, top\_p=None, min\_p=None, temperature=None, frequency\_penalty=None, presence\_penalty=None, repetition\_penalty=None, max\_new\_tokens=None, min\_new\_tokens=None, ignore\_eos=None, stop=None, stop\_token\_ids=None, detokenize=None, seed=None, logits\_processors=None)
Input dataclass for creating SamplingParams instances.
All fields are optional, allowing partial specification with None values
indicating “use default”. This enables static type checking while maintaining
the flexibility to specify only the parameters you want to override.
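The three-level priority that `from_input_and_generation_config()` applies to these fields can be sketched as a plain dictionary merge (hypothetical helper, not the actual implementation):

```python
# Sketch of the merge: user input beats the model's GenerationConfig
# defaults, which beat the SamplingParams class defaults. None means
# "not specified" at both of the higher-priority levels.
CLASS_DEFAULTS = {"temperature": 1.0, "top_k": -1, "top_p": 1.0}

def resolve(user: dict, gen_config: dict) -> dict:
    merged = dict(CLASS_DEFAULTS)
    merged.update({k: v for k, v in gen_config.items() if v is not None})
    merged.update({k: v for k, v in user.items() if v is not None})
    return merged

params = resolve({"temperature": 0.7}, {"temperature": 0.2, "top_k": 50})
assert params == {"temperature": 0.7, "top_k": 50, "top_p": 1.0}
```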
### `detokenize` {#max.interfaces.SamplingParamsInput.detokenize}
> detokenize: [bool](https://docs.python.org/3/library/functions.html#bool) | [None](https://docs.python.org/3/library/constants.html#None) = None
### `frequency_penalty` {#max.interfaces.SamplingParamsInput.frequency_penalty}
> frequency\_penalty: [float](https://docs.python.org/3/library/functions.html#float) | [None](https://docs.python.org/3/library/constants.html#None) = None
### `ignore_eos` {#max.interfaces.SamplingParamsInput.ignore_eos}
> ignore\_eos: [bool](https://docs.python.org/3/library/functions.html#bool) | [None](https://docs.python.org/3/library/constants.html#None) = None
### `logits_processors` {#max.interfaces.SamplingParamsInput.logits_processors}
> logits\_processors: [Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[Callable](graph/ops.md#max.graph.ops.Callable)\[\[[ProcessorInputs](#max.interfaces.ProcessorInputs)], [None](https://docs.python.org/3/library/constants.html#None)]] | [None](https://docs.python.org/3/library/constants.html#None) = None
### `max_new_tokens` {#max.interfaces.SamplingParamsInput.max_new_tokens}
> max\_new\_tokens: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) = None
### `min_new_tokens` {#max.interfaces.SamplingParamsInput.min_new_tokens}
> min\_new\_tokens: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) = None
### `min_p` {#max.interfaces.SamplingParamsInput.min_p}
> min\_p: [float](https://docs.python.org/3/library/functions.html#float) | [None](https://docs.python.org/3/library/constants.html#None) = None
### `presence_penalty` {#max.interfaces.SamplingParamsInput.presence_penalty}
> presence\_penalty: [float](https://docs.python.org/3/library/functions.html#float) | [None](https://docs.python.org/3/library/constants.html#None) = None
### `repetition_penalty` {#max.interfaces.SamplingParamsInput.repetition_penalty}
> repetition\_penalty: [float](https://docs.python.org/3/library/functions.html#float) | [None](https://docs.python.org/3/library/constants.html#None) = None
### `seed` {#max.interfaces.SamplingParamsInput.seed}
> seed: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) = None
### `stop` {#max.interfaces.SamplingParamsInput.stop}
> stop: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[str](https://docs.python.org/3/library/stdtypes.html#str)] | [None](https://docs.python.org/3/library/constants.html#None) = None
### `stop_token_ids` {#max.interfaces.SamplingParamsInput.stop_token_ids}
> stop\_token\_ids: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)] | [None](https://docs.python.org/3/library/constants.html#None) = None
### `temperature` {#max.interfaces.SamplingParamsInput.temperature}
> temperature: [float](https://docs.python.org/3/library/functions.html#float) | [None](https://docs.python.org/3/library/constants.html#None) = None
### `top_k` {#max.interfaces.SamplingParamsInput.top_k}
> top\_k: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) = None
### `top_p` {#max.interfaces.SamplingParamsInput.top_p}
> top\_p: [float](https://docs.python.org/3/library/functions.html#float) | [None](https://docs.python.org/3/library/constants.html#None) = None
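Since every field defaults to `None`, a request typically supplies only the overrides it needs. A minimal sketch using plain dicts and hypothetical values (the real class is a typed input object, not a dict):

```python
# Hypothetical sampling overrides using the documented field names.
# Any field left unset (None) falls through to the pipeline's defaults.
overrides = {
    "temperature": 0.7,
    "top_k": 40,
    "top_p": 0.95,
    "max_new_tokens": 256,
    "stop": ["\n\n"],
}

# Keep only explicitly supplied values (assumed merge behavior).
effective = {k: v for k, v in overrides.items() if v is not None}
```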
## `Scheduler` {#max.interfaces.Scheduler}
> class max.interfaces.Scheduler
Abstract base class defining the interface for schedulers.
### `run_iteration()` {#max.interfaces.Scheduler.run_iteration}
> abstract run\_iteration()
The core scheduler routine that creates and executes batches.
This method should implement the core scheduling logic including:
* Batch creation and management
* Request scheduling
## `SchedulerResult` {#max.interfaces.SchedulerResult}
> class max.interfaces.SchedulerResult(is\_done, result)
Structure representing the result of a scheduler operation for a specific pipeline execution.
This class encapsulates the outcome of a pipeline operation as managed by the scheduler,
including both the execution status and any resulting data from the pipeline. The scheduler
uses this structure to communicate the state of pipeline operations back to clients,
whether the operation is still running, has completed successfully, or was cancelled.
The generic type parameter allows this result to work with different types of pipeline
outputs while maintaining type safety.
**Parameters:**
* is\_done ([bool](https://docs.python.org/3/library/functions.html#bool))
* result (PipelineOutputType | None)
### `cancelled()` {#max.interfaces.SchedulerResult.cancelled}
> classmethod cancelled()
Create a SchedulerResult representing a cancelled pipeline operation.
### `create()` {#max.interfaces.SchedulerResult.create}
> classmethod create(result)
Create a SchedulerResult representing a pipeline operation with some result.
**Parameters:**
result (PipelineOutputType) – The pipeline output data.
### `is_done` {#max.interfaces.SchedulerResult.is_done}
> is\_done: [bool](https://docs.python.org/3/library/functions.html#bool)
The current status of the pipeline operation from the scheduler’s perspective.
### `result` {#max.interfaces.SchedulerResult.result}
> result: PipelineOutputType | [None](https://docs.python.org/3/library/constants.html#None)
The pipeline output data, if any. May be None for cancelled operations or during intermediate states of streaming operations.
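The lifecycle above can be sketched with a generic dataclass. `SchedulerResultSketch` is an illustration only; in particular, the `is_done` values chosen in `cancelled()` and `create()` are assumptions inferred from the field descriptions, not the real implementation:

```python
from dataclasses import dataclass
from typing import Generic, Optional, TypeVar

PipelineOutputType = TypeVar("PipelineOutputType")


@dataclass(frozen=True)
class SchedulerResultSketch(Generic[PipelineOutputType]):
    """Hypothetical stand-in mirroring the documented SchedulerResult fields."""

    is_done: bool
    result: Optional[PipelineOutputType]

    @classmethod
    def cancelled(cls) -> "SchedulerResultSketch[PipelineOutputType]":
        # Assumed: a cancelled operation is terminal and carries no result.
        return cls(is_done=True, result=None)

    @classmethod
    def create(cls, result: PipelineOutputType) -> "SchedulerResultSketch[PipelineOutputType]":
        # Assumed: a streaming intermediate result is not yet done.
        return cls(is_done=False, result=result)
```

The generic parameter lets the same wrapper carry any pipeline output type, e.g. `SchedulerResultSketch.create("partial text")` for a text pipeline.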
## `SharedMemoryArray` {#max.interfaces.SharedMemoryArray}
> class max.interfaces.SharedMemoryArray(name, shape, dtype)
Wrapper for numpy array stored in shared memory.
This class is used as a placeholder in pixel\_values during serialization.
It will be encoded as a dict with \_\_shm\_\_ flag and decoded back to a numpy
array.
**Parameters:**
* name ([str](https://docs.python.org/3/library/stdtypes.html#str))
* shape ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), ...])
* dtype ([str](https://docs.python.org/3/library/stdtypes.html#str))
## `TextContentPart` {#max.interfaces.TextContentPart}
> class max.interfaces.TextContentPart(\*, type='text', text)
**Parameters:**
* type ([Literal](https://docs.python.org/3/library/typing.html#typing.Literal)\['text'])
* text ([str](https://docs.python.org/3/library/stdtypes.html#str))
### `model_config` {#max.interfaces.TextContentPart.model_config}
> model\_config: ClassVar\[ConfigDict] = {'frozen': True}
Configuration for the model, should be a dictionary conforming to \[ConfigDict]\[pydantic.config.ConfigDict].
### `text` {#max.interfaces.TextContentPart.text}
> text: [str](https://docs.python.org/3/library/stdtypes.html#str)
### `type` {#max.interfaces.TextContentPart.type}
> type: Literal\['text']
## `TextGenerationContext` {#max.interfaces.TextGenerationContext}
> class max.interfaces.TextGenerationContext(\*args, \*\*kwargs)
Protocol defining the interface for text generation contexts in token generation.
A `TextGenerationContext` represents model inputs for text generation pipelines, managing
the state of tokens throughout the generation process. It handles token arrays,
generation status, sampling parameters, and various indices that track different
stages of token processing.
### `compute_num_available_steps()` {#max.interfaces.TextGenerationContext.compute_num_available_steps}
> compute\_num\_available\_steps(max\_seq\_len)
Compute the maximum number of generation steps available.
This method calculates how many tokens can be generated without
exceeding the specified maximum sequence length limit.
**Parameters:**
max\_seq\_len ([int](https://docs.python.org/3/library/functions.html#int)) – The maximum allowed sequence length for this context.
**Returns:**
The number of generation steps that can be executed before reaching
the sequence length limit.
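The calculation described here reduces to arithmetic on the current sequence length. A hypothetical free-function version (the real method reads the current length from the context's token buffer):

```python
def compute_num_available_steps(current_length: int, max_seq_len: int) -> int:
    """Steps remaining before the sequence would exceed max_seq_len.

    Hypothetical sketch of the documented method; clamps at zero when the
    context is already at or past the limit.
    """
    return max(0, max_seq_len - current_length)
```

For example, a context already holding 1000 tokens under a 1024-token limit has 24 steps available.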
### `eos_token_ids` {#max.interfaces.TextGenerationContext.eos_token_ids}
> property eos\_token\_ids: [set](https://docs.python.org/3/library/stdtypes.html#set)\[[int](https://docs.python.org/3/library/functions.html#int)]
The set of end-of-sequence token IDs that can terminate generation.
**Returns:**
A set of token IDs that, when generated, will signal the end of the
sequence and terminate the generation process.
### `get_min_token_logit_mask()` {#max.interfaces.TextGenerationContext.get_min_token_logit_mask}
> get\_min\_token\_logit\_mask(num\_steps)
Get token indices that should be masked in the output logits.
This method is primarily used to implement the `min_tokens` constraint,
where certain tokens (typically EOS tokens) are masked to prevent early
termination before the minimum token count is reached.
**Parameters:**
num\_steps ([int](https://docs.python.org/3/library/functions.html#int)) – The number of generation steps to compute masks for.
**Returns:**
A list of NumPy arrays, where each array contains token indices
that should be masked (set to negative infinity) in the logits
for the corresponding generation step.
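A sketch of how such per-step masks might be built for the `min_tokens` constraint; the logic below is an assumption for illustration, not the library's actual implementation:

```python
import numpy as np


def min_token_logit_masks(
    generated_so_far: int,
    min_tokens: int,
    eos_token_ids: set[int],
    num_steps: int,
) -> list[np.ndarray]:
    """One index array per step: EOS ids to mask while under min_tokens."""
    masks = []
    for step in range(num_steps):
        if generated_so_far + step < min_tokens:
            # Still below the minimum: suppress every EOS token this step.
            masks.append(np.array(sorted(eos_token_ids), dtype=np.int64))
        else:
            # Minimum reached: no tokens need masking.
            masks.append(np.array([], dtype=np.int64))
    return masks
```

A caller would then apply each array to that step's logits, e.g. `logits[mask] = -np.inf`, so EOS tokens cannot be sampled early.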
### `is_initial_prompt` {#max.interfaces.TextGenerationContext.is_initial_prompt}
> property is\_initial\_prompt: [bool](https://docs.python.org/3/library/functions.html#bool)
Whether this context contains only the initial prompt.
This property indicates if the context has not yet been updated with
any generated tokens and still contains only the original input.
**Returns:**
`True` if no tokens have been generated yet, `False` if generation
has begun and tokens have been added.
### `json_schema` {#max.interfaces.TextGenerationContext.json_schema}
> property json\_schema: [str](https://docs.python.org/3/library/stdtypes.html#str) | [None](https://docs.python.org/3/library/constants.html#None)
The JSON schema for constrained decoding, if configured.
When set, this schema constrains token generation to produce valid JSON
output that conforms to the specified structure.
**Returns:**
The JSON schema string, or `None` if no schema constraint is active.
### `jump_ahead()` {#max.interfaces.TextGenerationContext.jump_ahead}
> jump\_ahead(new\_token)
Jump ahead in generation by adding a token and updating indices.
This method is used in speculative decoding scenarios to quickly
advance the generation state when draft tokens are accepted.
**Parameters:**
new\_token ([int](https://docs.python.org/3/library/functions.html#int)) – The token ID to add when jumping ahead in the sequence.
**Return type:**
None
### `log_probabilities` {#max.interfaces.TextGenerationContext.log_probabilities}
> property log\_probabilities: [int](https://docs.python.org/3/library/functions.html#int)
The number of top tokens to return log probabilities for.
When greater than 0, the system returns log probabilities for the top N
most likely tokens at each generation step.
**Returns:**
The number of top tokens to include in log probability output.
Returns 0 if log probabilities are disabled.
### `log_probabilities_echo` {#max.interfaces.TextGenerationContext.log_probabilities_echo}
> property log\_probabilities\_echo: [bool](https://docs.python.org/3/library/functions.html#bool)
Whether to include input tokens in the returned log probabilities.
When `True`, log probabilities will be computed and returned for input
(prompt) tokens in addition to generated tokens.
**Returns:**
`True` if input tokens should be included in log probability output,
`False` otherwise.
### `matcher` {#max.interfaces.TextGenerationContext.matcher}
> property matcher: [Any](https://docs.python.org/3/library/typing.html#typing.Any) | [None](https://docs.python.org/3/library/constants.html#None)
The grammar matcher for structured output generation, if configured.
The matcher enforces structural constraints (like JSON schema) during
generation to ensure valid formatted output.
**Returns:**
The grammar matcher instance, or None if no structured generation
is configured for this context.
:::note Note
The matcher type depends on the structured generation backend used
(e.g., outlines, guidance, etc.). In the future, this should be
replaced with a Protocol for better type safety.
:::
### `max_length` {#max.interfaces.TextGenerationContext.max_length}
> property max\_length: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None)
The maximum allowed length for this sequence.
When set, generation will stop when this length is reached, regardless
of other stopping criteria.
**Returns:**
The maximum sequence length limit, or `None` if no limit is set.
### `min_tokens` {#max.interfaces.TextGenerationContext.min_tokens}
> property min\_tokens: [int](https://docs.python.org/3/library/functions.html#int)
The minimum number of new tokens that must be generated.
Generation will continue until at least this many new tokens have been
produced, even if other stopping criteria are met (e.g., EOS tokens).
**Returns:**
The minimum number of new tokens to generate.
### `realize_future_token()` {#max.interfaces.TextGenerationContext.realize_future_token}
> realize\_future\_token(new\_token, log\_probabilities=None)
Overwrite the placeholder future token with the actual token.
This is primarily used for overlap scheduling.
### `reset()` {#max.interfaces.TextGenerationContext.reset}
> reset()
Resets the context’s state by combining all tokens into a new prompt.
This method is used when a request is evicted, meaning that the context
needs to be re-encoded in the following context-encoding (CE) iteration.
**Return type:**
None
### `sampling_params` {#max.interfaces.TextGenerationContext.sampling_params}
> property sampling\_params: [SamplingParams](#max.interfaces.SamplingParams)
The sampling parameters configured for this generation request.
These parameters control how tokens are selected during generation,
including temperature, top-k/top-p filtering, and stopping criteria.
**Returns:**
The `SamplingParams` instance containing all sampling configuration
for this context.
### `set_matcher()` {#max.interfaces.TextGenerationContext.set_matcher}
> set\_matcher(matcher)
Set a grammar matcher for constrained decoding.
This method configures structured output generation by installing a
grammar matcher that enforces format constraints during token generation.
**Parameters:**
matcher ([Any](https://docs.python.org/3/library/typing.html#typing.Any)) – The grammar matcher instance to use for constraining output.
The specific type depends on the structured generation backend.
**Return type:**
None
### `to_generation_output()` {#max.interfaces.TextGenerationContext.to_generation_output}
> to\_generation\_output()
Convert this context to a TextGenerationOutput object.
This property provides a standardized way to extract the final output
of the text generation process from the context, including generated
text, tokens, and any associated metadata.
**Returns:**
The output object containing the results of
the text generation for this context.
### `tokens` {#max.interfaces.TextGenerationContext.tokens}
> property tokens: [TokenBuffer](#max.interfaces.TokenBuffer)
The token buffer for the context.
### `update()` {#max.interfaces.TextGenerationContext.update}
> update(new\_token, log\_probabilities=None)
Update the context with a newly generated token, and update status.
This method adds a generated token to the context, updating the token
array, associated metadata, and log probabilities (if provided).
It is also responsible for updating the context’s generation status and
determining if the generation sequence is complete, either due to reaching
an end-of-sequence condition or meeting stopping criteria.
**Parameters:**
* new\_token ([int](https://docs.python.org/3/library/functions.html#int)) – The token ID to add to the generation sequence.
* log\_probabilities ([LogProbabilities](#max.interfaces.LogProbabilities) | None) – Optional log probability data for the new token
and alternatives. Used for analysis and debugging.
**Return type:**
None
### `update_with_future_token()` {#max.interfaces.TextGenerationContext.update_with_future_token}
> update\_with\_future\_token()
Append a placeholder future token to the generated tokens.
This is primarily used for overlap scheduling.
**Return type:**
None
## `TextGenerationInputs` {#max.interfaces.TextGenerationInputs}
> class max.interfaces.TextGenerationInputs(batches, num\_steps, input\_tokens=-1, batch\_type=BatchType.TG)
Input parameters for text generation pipeline operations.
This class encapsulates the batch of contexts and number of steps required
for token generation in a single input object, replacing the previous
pattern of passing batch and num\_steps as separate parameters.
### `batch_echo` {#max.interfaces.TextGenerationInputs.batch_echo}
> property batch\_echo: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[bool](https://docs.python.org/3/library/functions.html#bool)]
List indicating whether echo is enabled for each context in the batch.
### `batch_top_log_probs` {#max.interfaces.TextGenerationInputs.batch_top_log_probs}
> property batch\_top\_log\_probs: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]
List of requested top log probabilities per context in the batch.
### `batch_type` {#max.interfaces.TextGenerationInputs.batch_type}
> batch\_type: [BatchType](#max.interfaces.BatchType) = 'TG'
Type of batch.
### `batches` {#max.interfaces.TextGenerationInputs.batches}
> batches: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[list](https://docs.python.org/3/library/stdtypes.html#list)\[TextGenerationContextType]]
Variable list of batches, with each batch being a list of contexts.
There can be multiple batches when using data parallelism, in which case
each batch is mapped to a different device replica.
### `enable_echo` {#max.interfaces.TextGenerationInputs.enable_echo}
> property enable\_echo: [bool](https://docs.python.org/3/library/functions.html#bool)
Return True if any context in the batch has echo enabled.
### `enable_log_probs` {#max.interfaces.TextGenerationInputs.enable_log_probs}
> property enable\_log\_probs: [bool](https://docs.python.org/3/library/functions.html#bool)
Return True if any context in the batch requests log probabilities.
### `flat_batch` {#max.interfaces.TextGenerationInputs.flat_batch}
> property flat\_batch: [list](https://docs.python.org/3/library/stdtypes.html#list)\[TextGenerationContextType]
Flattened list of contexts across all replicas.
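Flattening the nested per-replica structure is a single comprehension; the strings below are stand-ins for real context objects:

```python
# Hypothetical batches: one inner list per device replica.
batches = [["ctx_a", "ctx_b"], ["ctx_c"]]

# flat_batch concatenates the per-replica batches in order.
flat_batch = [ctx for replica in batches for ctx in replica]
```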
### `input_tokens` {#max.interfaces.TextGenerationInputs.input_tokens}
> input\_tokens: [int](https://docs.python.org/3/library/functions.html#int) = -1
Number of input tokens.
### `num_steps` {#max.interfaces.TextGenerationInputs.num_steps}
> num\_steps: [int](https://docs.python.org/3/library/functions.html#int)
Number of steps to run for.
## `TextGenerationOutput` {#max.interfaces.TextGenerationOutput}
> class max.interfaces.TextGenerationOutput(\*, request\_id, tokens, final\_status, log\_probabilities=None)
Represents the output of a text generation operation, combining token IDs,
final generation status, request ID, and optional log probabilities for each token.
### `final_status` {#max.interfaces.TextGenerationOutput.final_status}
> final\_status: [GenerationStatus](#max.interfaces.GenerationStatus)
The final status of the generation process.
### `is_done` {#max.interfaces.TextGenerationOutput.is_done}
> property is\_done: [bool](https://docs.python.org/3/library/functions.html#bool)
Indicates whether the text generation process is complete.
## `TextGenerationRequest` {#max.interfaces.TextGenerationRequest}
### `chat_template_options` {#max.interfaces.TextGenerationRequest.chat_template_options}
> chat\_template\_options: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [None](https://docs.python.org/3/library/constants.html#None) = None
Optional dictionary of options to pass when applying the chat template.
### `echo` {#max.interfaces.TextGenerationRequest.echo}
> echo: [bool](https://docs.python.org/3/library/functions.html#bool) = False
If set to True, the response will include the original prompt along with the
generated output. This can be useful for debugging or when you want to see how
the input relates to the output.
### `images` {#max.interfaces.TextGenerationRequest.images}
> images: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[bytes](https://docs.python.org/3/library/stdtypes.html#bytes)]
A list of image byte arrays that can be included as part of the request.
This field is optional and may be used for multimodal inputs where images
are relevant to the prompt or task.
### `logprobs` {#max.interfaces.TextGenerationRequest.logprobs}
> logprobs: [int](https://docs.python.org/3/library/functions.html#int) = 0
The number of top log probabilities to return for each generated token. A value
of 0 means that log probabilities will not be returned. Useful for analyzing
model confidence in its predictions.
### `messages` {#max.interfaces.TextGenerationRequest.messages}
> messages: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[TextGenerationRequestMessage](#max.interfaces.TextGenerationRequestMessage)]
A list of messages for chat-based interactions. This is used in chat
completion APIs, where each message represents a turn in the conversation.
If provided, the model will generate responses based on these messages.
### `model_name` {#max.interfaces.TextGenerationRequest.model_name}
> model\_name: [str](https://docs.python.org/3/library/stdtypes.html#str)
The name of the model to be used for generating tokens. This should match
the available models on the server and determines the behavior and
capabilities of the response generation.
### `number_of_images` {#max.interfaces.TextGenerationRequest.number_of_images}
> property number\_of\_images: [int](https://docs.python.org/3/library/functions.html#int)
Returns the total number of image-type contents across all provided messages.
**Returns:**
Total count of image-type contents found in messages.
### `prompt` {#max.interfaces.TextGenerationRequest.prompt}
> prompt: [str](https://docs.python.org/3/library/stdtypes.html#str) | [Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[int](https://docs.python.org/3/library/functions.html#int)] | [None](https://docs.python.org/3/library/constants.html#None) = None
The prompt to be processed by the model. This field supports legacy
completion APIs and can accept either a string or a sequence of integers
representing token IDs. If not provided, the model may generate output
based on the messages field.
### `request_path` {#max.interfaces.TextGenerationRequest.request_path}
> request\_path: [str](https://docs.python.org/3/library/stdtypes.html#str) = '/'
The endpoint path for the request. This is typically used for routing and
logging requests within the server infrastructure.
### `response_format` {#max.interfaces.TextGenerationRequest.response_format}
> response\_format: [TextGenerationResponseFormat](#max.interfaces.TextGenerationResponseFormat) | [None](https://docs.python.org/3/library/constants.html#None) = None
Specifies the desired format for the model’s output. When set, it enables
structured generation, which adheres to the json\_schema provided.
### `sampling_params` {#max.interfaces.TextGenerationRequest.sampling_params}
> sampling\_params: [SamplingParams](#max.interfaces.SamplingParams)
Token sampling configuration parameters for the request.
### `stop` {#max.interfaces.TextGenerationRequest.stop}
> stop: [str](https://docs.python.org/3/library/stdtypes.html#str) | [list](https://docs.python.org/3/library/stdtypes.html#list)\[[str](https://docs.python.org/3/library/stdtypes.html#str)] | [None](https://docs.python.org/3/library/constants.html#None) = None
Optional list of stop expressions (see the [OpenAI `stop` parameter](https://platform.openai.com/docs/api-reference/chat/create#chat-create-stop)).
### `target_endpoint` {#max.interfaces.TextGenerationRequest.target_endpoint}
> target\_endpoint: [str](https://docs.python.org/3/library/stdtypes.html#str) | [None](https://docs.python.org/3/library/constants.html#None) = None
Optional target endpoint identifier for routing the request to a specific
service or model instance. This should be used in disaggregate serving
scenarios, when you want to dynamically route to a specific instance.
If not specified, the request will be routed to the default endpoint.
### `timestamp_ns` {#max.interfaces.TextGenerationRequest.timestamp_ns}
> timestamp\_ns: [int](https://docs.python.org/3/library/functions.html#int) = 0
The time (in nanoseconds) when the request was received by the server. This
can be useful for performance monitoring and logging purposes.
### `tools` {#max.interfaces.TextGenerationRequest.tools}
> tools: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[TextGenerationRequestTool](#max.interfaces.TextGenerationRequestTool)] | [None](https://docs.python.org/3/library/constants.html#None) = None
A list of tools that can be invoked during the generation process. This
allows the model to utilize external functionalities or APIs to enhance its
responses.
## `TextGenerationRequestFunction` {#max.interfaces.TextGenerationRequestFunction}
> class max.interfaces.TextGenerationRequestFunction
Represents a function definition for a text generation request.
### `description` {#max.interfaces.TextGenerationRequestFunction.description}
> description: [str](https://docs.python.org/3/library/stdtypes.html#str) | [None](https://docs.python.org/3/library/constants.html#None)
A human-readable description of the function’s purpose.
### `name` {#max.interfaces.TextGenerationRequestFunction.name}
> name: [str](https://docs.python.org/3/library/stdtypes.html#str)
The name of the function to be invoked.
### `parameters` {#max.interfaces.TextGenerationRequestFunction.parameters}
> parameters: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Any](https://docs.python.org/3/library/typing.html#typing.Any)]
A dictionary describing the function’s parameters, typically following a JSON schema.
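A hedged example of a function definition using the documented field names; the weather-lookup function itself is hypothetical:

```python
# Hypothetical function definition; `parameters` follows JSON schema.
get_weather = {
    "name": "get_weather",
    "description": "Look up the current weather for a city.",
    "parameters": {
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
}
```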
## `TextGenerationRequestMessage` {#max.interfaces.TextGenerationRequestMessage}
> class max.interfaces.TextGenerationRequestMessage(\*, role, content)
### `model_config` {#max.interfaces.TextGenerationRequestMessage.model_config}
> model\_config: ClassVar\[ConfigDict] = {'from\_attributes': True, 'frozen': True}
Configuration for the model, should be a dictionary conforming to \[ConfigDict]\[pydantic.config.ConfigDict].
### `number_of_images` {#max.interfaces.TextGenerationRequestMessage.number_of_images}
> property number\_of\_images: [int](https://docs.python.org/3/library/functions.html#int)
Returns the number of ImageContentPart instances in the message content.
### `role` {#max.interfaces.TextGenerationRequestMessage.role}
> role: MessageRole
### `validate_content_format()` {#max.interfaces.TextGenerationRequestMessage.validate_content_format}
> classmethod validate\_content\_format(v)
**Parameters:**
v ([Any](https://docs.python.org/3/library/typing.html#typing.Any))
## `TextGenerationRequestTool` {#max.interfaces.TextGenerationRequestTool}
> class max.interfaces.TextGenerationRequestTool
Represents a tool definition for a text generation request.
### `function` {#max.interfaces.TextGenerationRequestTool.function}
> function: [TextGenerationRequestFunction](#max.interfaces.TextGenerationRequestFunction)
The function definition associated with the tool, including its name, description, and parameters.
### `type` {#max.interfaces.TextGenerationRequestTool.type}
> type: [str](https://docs.python.org/3/library/stdtypes.html#str)
The type of the tool, typically indicating the tool’s category or usage.
## `TextGenerationResponseFormat` {#max.interfaces.TextGenerationResponseFormat}
> class max.interfaces.TextGenerationResponseFormat
Represents the response format specification for a text generation request.
### `json_schema` {#max.interfaces.TextGenerationResponseFormat.json_schema}
> json\_schema: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Any](https://docs.python.org/3/library/typing.html#typing.Any)]
A JSON schema dictionary that defines the structure and validation rules for the generated response.
### `type` {#max.interfaces.TextGenerationResponseFormat.type}
> type: [str](https://docs.python.org/3/library/stdtypes.html#str)
The type of response format, e.g., “json\_object”.
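A hypothetical `response_format` payload using the documented fields, constraining output to a one-key JSON object:

```python
# Hypothetical structured-generation request fragment; the schema content
# is illustrative only.
response_format = {
    "type": "json_object",
    "json_schema": {
        "type": "object",
        "properties": {"answer": {"type": "string"}},
        "required": ["answer"],
    },
}
```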
## `TokenBuffer` {#max.interfaces.TokenBuffer}
> class max.interfaces.TokenBuffer(array)
A dynamically resizable container for managing token sequences.
TokenBuffer provides efficient storage and access to token sequences during
text generation. It maintains the prompt tokens (initial input) and generated
tokens (model output) separately, while handling automatic memory management
as new tokens are added.
TokenBuffer organizes tokens across three related views:
1. The full stored sequence (all), split into prompt and generated tokens.
2. The processing window (active versus processed and pending tokens).
3. The streaming window over newly generated tokens consumed by callers.
The first diagram shows how prompt and generated tokens share a single
backing array. Later diagrams explain how processing and streaming walk
over that array during generation:
```default
+-------------------- all --------------------+
+-----------------+---------------------------+
| prompt | generated |
+-----------------+---------------------------+
0 prompt_length ^ generated_length ^
0 len(self) ^
```
This includes three attributes for accessing tokens:
* all: The slice of the array containing all valid tokens.
* prompt: The slice of the array containing the prompt tokens.
* generated: The slice of the array containing the generated tokens.
Along with three attributes for tracking their lengths:
* prompt\_length: The number of tokens in the prompt.
* generated\_length: The number of generated tokens.
* len(self): The total number of valid tokens in the buffer.
Processing window (what the model will process next):
```default
+-------------------------------- all -------------------------+
+-------------------+---------------------------+---------------+
| processed | active | pending |
+-------------------+---------------------------+---------------+
0 processed_length ^ active_length ^
0 current_position ^
0 len(self) ^
```
In the above, processed tracks tokens that have already been processed,
active tracks tokens scheduled for processing in the next batch, and
pending tracks tokens that have not yet been processed but are not
scheduled for the next batch (this commonly occurs during chunked
prefill).
This includes one attribute for accessing tokens:
* active: The slice of the array containing the tokens scheduled
for processing in the next batch.
Along with three additional attributes for tracking their lengths:
* processed\_length: The number of tokens that have already been processed.
* active\_length: The number of tokens currently scheduled for
processing in the next batch.
* current\_position: The global index marking the end of the current
active processing window.
This processing view is updated by methods such as rewind\_processing,
skip\_processing, chunk, and advance\_chunk/advance\_with\_token, which
control how much of the existing sequence is reprocessed or advanced at
each step.
It also maintains a streaming window over the generated tokens
for completion streaming:
```default
+------------- generated -------------+
+------------+------------------------+
| streamed | ready to stream next |
+------------+------------------------+
| (1) | (2) |
```
Generated tokens are conceptually split into:
1. **streamed**: tokens that have already been returned to the caller.
2. **ready to stream**: the newest generated tokens that have not yet
been returned.
Each call to consume\_recently\_generated\_tokens() returns the (2) region
and advances the boundary between (1) and (2), so subsequent calls only
see newly generated tokens.
Together, these three views let TokenBuffer support efficient prompt
handling, chunked processing, and incremental streaming while exposing a small,
consistent public API.
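The three views can be illustrated with a minimal stand-in class. The behavior below reflects the descriptions above under stated assumptions; it is not the real `TokenBuffer` implementation:

```python
import numpy as np


class TokenBufferSketch:
    """Illustrative stand-in for the documented prompt/processing/streaming views."""

    def __init__(self, prompt_tokens: list[int]) -> None:
        self.array = np.array(prompt_tokens, dtype=np.int64)
        self.prompt_length = len(prompt_tokens)
        self.processed_length = 0                 # start of the processing window
        self.active_length = len(prompt_tokens)   # whole prompt pending at first
        self._streamed = 0                        # streaming boundary in generated

    @property
    def generated_length(self) -> int:
        return len(self.array) - self.prompt_length

    @property
    def active(self) -> np.ndarray:
        start = self.processed_length
        return self.array[start:start + self.active_length]

    def chunk(self, chunk_size: int) -> None:
        # Limit the next processing step (chunked prefill).
        self.active_length = min(self.active_length, chunk_size)

    def advance_chunk(self) -> None:
        # The remaining unprocessed tokens become active.
        self.processed_length += self.active_length
        self.active_length = len(self.array) - self.processed_length

    def advance_with_token(self, token: int) -> None:
        # Mark the active window processed, append the new token, and make
        # it the sole active token for the next step.
        self.processed_length += self.active_length
        self.array = np.append(self.array, np.int64(token))
        self.active_length = 1

    def consume_recently_generated_tokens(self) -> np.ndarray:
        # Return generated tokens not yet streamed; advance the boundary.
        generated = self.array[self.prompt_length:]
        out = generated[self._streamed:]
        self._streamed = len(generated)
        return out
```

Walking a four-token prompt through a two-token chunk, then generating two tokens, exercises all three views: `chunk(2)` narrows the active window, `advance_chunk()` makes the rest of the prompt active, and each `advance_with_token()` leaves exactly one newly generated token ready to stream.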
### `active` {#max.interfaces.TokenBuffer.active}
> property active: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[int64]]
Return the tokens queued for the next processing step.
### `active_length` {#max.interfaces.TokenBuffer.active_length}
> property active\_length: [int](https://docs.python.org/3/library/functions.html#int)
Count of tokens currently scheduled for processing.
### `actively_chunked` {#max.interfaces.TokenBuffer.actively_chunked}
> property actively\_chunked: [bool](https://docs.python.org/3/library/functions.html#bool)
Check if the buffer has active chunk limits applied.
**Returns:**
True if chunk limits are active, False otherwise.
### `advance_chunk()` {#max.interfaces.TokenBuffer.advance_chunk}
> advance\_chunk()
Move to the next set of tokens after a limited chunk.
Call this after maybe\_chunk when you have finished working with the
current active tokens and want the remaining tokens in the sequence
to become active.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If called before maybe\_chunk has limited the active
tokens (i.e., when no chunk is currently active).
**Return type:**
None
### `advance_with_token()` {#max.interfaces.TokenBuffer.advance_with_token}
> advance\_with\_token(token, mark\_previous\_as\_processed=True)
Add a new token to the buffer.
**Parameters:**
* token ([int](https://docs.python.org/3/library/functions.html#int)) – The token ID to add.
* mark\_previous\_as\_processed ([bool](https://docs.python.org/3/library/functions.html#bool)) – If False, expands the set of active tokens instead of
shifting forward. This is useful for speculative execution
scenarios where multiple tokens may be generated.
**Return type:**
None
### `all` {#max.interfaces.TokenBuffer.all}
> property all: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[int64]]
Return every valid token currently stored (prompt + generated).
Use this when downstream components need the full sequence for scoring,
logging, or serialization.
### `apply_processing_offset()` {#max.interfaces.TokenBuffer.apply_processing_offset}
> apply\_processing\_offset(value)
Set the processing offset.
**Parameters:**
value ([int](https://docs.python.org/3/library/functions.html#int)) – The new processing offset.
**Return type:**
None
### `array` {#max.interfaces.TokenBuffer.array}
> array: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[int64]]
In-place storage holding the prompt plus any generated tokens.
### `chunk()` {#max.interfaces.TokenBuffer.chunk}
> chunk(chunk\_size)
Limit the upcoming processing step to at most chunk\_size tokens.
**Parameters:**
chunk\_size ([int](https://docs.python.org/3/library/functions.html#int)) – Maximum number of tokens to process.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If chunk\_size is not between 1 and the current number of active tokens.
**Return type:**
None
### `consume_recently_generated_tokens()` {#max.interfaces.TokenBuffer.consume_recently_generated_tokens}
> consume\_recently\_generated\_tokens()
Return newly generated tokens since the last consumption.
**Returns:**
A slice containing tokens ready to stream to the caller.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If no new tokens are available.
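The consumption pattern behind `advance_with_token` and `consume_recently_generated_tokens` amounts to a cursor over the token array: each consumption returns only tokens appended since the previous one. A pure-Python sketch (hypothetical internals, not the real `TokenBuffer`):

```python
# Illustrative sketch of streaming consumption: a cursor tracks what has
# already been handed to the caller, so each call returns only new tokens.
# NOT the real max.interfaces.TokenBuffer implementation.

class StreamingBuffer:
    def __init__(self, prompt):
        self.tokens = list(prompt)
        self.consumed = len(prompt)  # prompt tokens are never streamed back

    def advance_with_token(self, token):
        self.tokens.append(token)

    def consume_recently_generated_tokens(self):
        if self.consumed == len(self.tokens):
            raise ValueError("no new tokens are available")
        new = self.tokens[self.consumed:]
        self.consumed = len(self.tokens)
        return new

buf = StreamingBuffer([101, 102])
buf.advance_with_token(7)
buf.advance_with_token(8)
assert buf.consume_recently_generated_tokens() == [7, 8]
```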
### `current_position` {#max.interfaces.TokenBuffer.current_position}
> property current\_position: [int](https://docs.python.org/3/library/functions.html#int)
Global index marking the end of the current active processing window.
Equal to processed\_length + active\_length; represents the index of
the next token to be processed, which may be less than the total length
when processing is limited by chunking.
### `generated` {#max.interfaces.TokenBuffer.generated}
> property generated: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[int64]]
Return all tokens produced after the prompt.
Use this slice for stop checks, repetition penalties, or any logic that
should consider only newly generated content.
### `generated_length` {#max.interfaces.TokenBuffer.generated_length}
> property generated\_length: [int](https://docs.python.org/3/library/functions.html#int)
Number of tokens generated after the prompt.
### `has_outstanding_generated_tokens` {#max.interfaces.TokenBuffer.has_outstanding_generated_tokens}
> property has\_outstanding\_generated\_tokens: [bool](https://docs.python.org/3/library/functions.html#bool)
Indicates whether there are generated tokens that have not yet been consumed.
**Returns:**
True if there are outstanding generated tokens to be streamed or processed; False otherwise.
### `overwrite_last_token()` {#max.interfaces.TokenBuffer.overwrite_last_token}
> overwrite\_last\_token(token)
Overwrite the last token in the buffer.
### `processed_length` {#max.interfaces.TokenBuffer.processed_length}
> property processed\_length: [int](https://docs.python.org/3/library/functions.html#int)
Number of tokens that have already been processed.
### `prompt` {#max.interfaces.TokenBuffer.prompt}
> property prompt: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[int64]]
Return only the original prompt tokens.
Helpful for echo suppression, prompt-side metrics, or offset
calculations that should exclude generated output.
### `prompt_length` {#max.interfaces.TokenBuffer.prompt_length}
> property prompt\_length: [int](https://docs.python.org/3/library/functions.html#int)
Number of tokens that belong to the prompt.
### `reset_as_new_prompt()` {#max.interfaces.TokenBuffer.reset_as_new_prompt}
> reset\_as\_new\_prompt()
Treat the current sequence as a fresh prompt.
Marks all existing tokens as prompt tokens so the next generation pass
starts from this state.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If the buffer state is invalid.
**Return type:**
None
### `rewind_processing()` {#max.interfaces.TokenBuffer.rewind_processing}
> rewind\_processing(n)
Re-expose n earlier tokens so they can be processed again.
**Parameters:**
n ([int](https://docs.python.org/3/library/functions.html#int)) – Number of tokens to move back into the active window.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If n is negative.
**Return type:**
None
### `skip_processing()` {#max.interfaces.TokenBuffer.skip_processing}
> skip\_processing(n)
Advance the active window start by n tokens.
**Parameters:**
n ([int](https://docs.python.org/3/library/functions.html#int)) – Number of tokens to drop from the active window.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If n exceeds the number of available tokens to process,
or if skipping n tokens would leave 0 active tokens.
**Return type:**
None
## `VLMTextGenerationContext` {#max.interfaces.VLMTextGenerationContext}
> class max.interfaces.VLMTextGenerationContext(\*args, \*\*kwargs)
Protocol defining the interface for VLM input contexts.
### `compute_image_aligned_idx()` {#max.interfaces.VLMTextGenerationContext.compute_image_aligned_idx}
> compute\_image\_aligned\_idx(idx)
Possibly aligns an index value downward if it lies in the middle of an image.
### `image_idx` {#max.interfaces.VLMTextGenerationContext.image_idx}
> property image\_idx: [int](https://docs.python.org/3/library/functions.html#int)
Index of the next unencoded image in the prompt.
### `images` {#max.interfaces.VLMTextGenerationContext.images}
> property images: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[ImageMetadata](#max.interfaces.ImageMetadata)]
Returns the images in the context.
### `needs_vision_encoding` {#max.interfaces.VLMTextGenerationContext.needs_vision_encoding}
> property needs\_vision\_encoding: [bool](https://docs.python.org/3/library/functions.html#bool)
Returns whether vision encoding is needed for this context.
### `next_images` {#max.interfaces.VLMTextGenerationContext.next_images}
> property next\_images: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[ImageMetadata](#max.interfaces.ImageMetadata)]
Returns the images that are not yet encoded.
## `drain_queue()` {#max.interfaces.drain_queue}
> max.interfaces.drain\_queue(pull\_queue, max\_items=None)
Remove and return items from the queue without blocking.
This method is expected to return an empty list if the queue is empty.
If max\_items is specified, at most that many items will be returned.
**Parameters:**
* pull\_queue ([MAXPullQueue](#max.interfaces.MAXPullQueue)\[PullItemType]) – The queue to drain items from.
* max\_items ([int](https://docs.python.org/3/library/functions.html#int) | None) – Maximum number of items to return. If None, returns all items.
**Returns:**
List of items removed from the queue, limited by max\_items if specified.
## `get_blocking()` {#max.interfaces.get_blocking}
> max.interfaces.get\_blocking(pull\_queue)
Get the next item from the queue.
If no item is available, this method will spin until one is.
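These two helpers follow standard queue idioms. A stdlib-only sketch of the same behavior over `queue.Queue` (the real functions operate on a `MAXPullQueue`, so this is an approximation of their contracts, not the actual implementation):

```python
import queue

def drain_queue(pull_queue, max_items=None):
    """Non-blocking drain: return up to max_items items, or [] if empty."""
    items = []
    while max_items is None or len(items) < max_items:
        try:
            items.append(pull_queue.get_nowait())
        except queue.Empty:
            break
    return items

def get_blocking(pull_queue):
    """Spin until an item is available, then return it."""
    while True:
        try:
            return pull_queue.get_nowait()
        except queue.Empty:
            continue  # busy-wait, mirroring the documented spin behavior

q = queue.Queue()
for i in range(5):
    q.put(i)
assert drain_queue(q, max_items=3) == [0, 1, 2]
assert get_blocking(q) == 3
assert drain_queue(q) == [4]
assert drain_queue(q) == []  # empty queue yields an empty list
```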
## `msgpack_numpy_decoder()` {#max.interfaces.msgpack_numpy_decoder}
> max.interfaces.msgpack\_numpy\_decoder(type\_, copy=True)
Create a decoder function for the specified type.
**Parameters:**
* type\_ ([Any](https://docs.python.org/3/library/typing.html#typing.Any)) – The type to decode into.
* copy ([bool](https://docs.python.org/3/library/functions.html#bool)) – Copy numpy arrays if True. Defaults to True.
Copy is set to True by default because most downstream usage of deserialized tensors is as MAX driver tensors, which require owned numpy arrays.
This is a constraint imposed by dlpack and numpy: we cannot create a buffer from read-only data.
While skipping copies by default would speed up deserialization, doing so often just moves the work downstream to an implicit copy during Buffer.from\_numpy.
As a result, it is easier to make the copy explicit here and maintain the pattern that all numpy arrays used in MAX are owned by the current process.
**Returns:**
A pickleable decoder instance that decodes bytes into the specified type
**Return type:**
MsgpackNumpyDecoder
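The copy-by-default rationale comes down to buffer ownership: a zero-copy view over received bytes is read-only, and read-only memory cannot back a writable buffer. A stdlib-only illustration of that constraint (numpy's `frombuffer` and dlpack behave analogously):

```python
# A bytes object is immutable, so any zero-copy view of it is read-only;
# constructing a writable buffer requires taking an owned copy first.
payload = bytes(range(8))   # stands in for a received serialized payload

view = memoryview(payload)  # zero-copy view over the immutable bytes
assert view.readonly        # cannot hand this to a writable consumer

owned = bytearray(payload)  # explicit copy -> owned, writable memory
writable = memoryview(owned)
assert not writable.readonly
writable[0] = 255           # safe: this process owns the buffer
assert owned[0] == 255
```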
## `msgpack_numpy_encoder()` {#max.interfaces.msgpack_numpy_encoder}
> max.interfaces.msgpack\_numpy\_encoder(use\_shared\_memory=False, shared\_memory\_threshold=0)
Create an encoder function that handles numpy arrays.
**Parameters:**
* use\_shared\_memory ([bool](https://docs.python.org/3/library/functions.html#bool)) – Whether to attempt shared memory conversion for numpy arrays
* shared\_memory\_threshold ([int](https://docs.python.org/3/library/functions.html#int)) – Minimum size in bytes for shared memory conversion.
If 0, all arrays are candidates for conversion.
**Returns:**
A pickleable encoder instance that encodes objects into bytes
**Return type:**
MsgpackNumpyEncoder
---
## kv_cache
KV cache management for efficient attention computation during inference.
This package provides implementations for managing key-value caches used in
transformer models. The paged attention implementation enables efficient memory
management by fragmenting cache memory into pages, allowing for better memory
utilization and support for prefix caching.
## Functions
* [`load_kv_manager`](/max/api/python/kv_cache/registry): Load and initialize a KV cache manager.
* [`estimate_kv_cache_size`](/max/api/python/kv_cache/registry): Estimate KV cache memory requirements.
* [`infer_optimal_batch_size`](/max/api/python/kv_cache/registry): Infer optimal batch size based on available cache memory.
* [`available_port`](/max/api/python/kv_cache/paged_kv_cache/transfer_engine): Find an available TCP port for transfer engine communication.
## Modules
* [`registry`](/max/api/python/kv_cache/registry): KV cache manager factory functions and utilities.
## Packages
* [`paged_kv_cache`](/max/api/python/kv_cache/paged_kv_cache): Paged attention KV cache implementation.
## Classes
* [`PagedKVCacheManager`](/max/api/python/kv_cache/paged_kv_cache/cache_manager): Manager for paged KV cache with data and tensor parallelism support.
* [`KVTransferEngine`](/max/api/python/kv_cache/paged_kv_cache/transfer_engine): Manages KV cache transfers between devices in distributed settings.
* [`KVTransferEngineMetadata`](/max/api/python/kv_cache/paged_kv_cache/transfer_engine): Metadata for KV cache transfer engine configuration.
* [`TransferReqData`](/max/api/python/kv_cache/paged_kv_cache/transfer_engine): Data structure for KV cache transfer requests.
---
## cache_manager
## `PagedKVCacheManager` {#max.kv_cache.paged_kv_cache.cache_manager.PagedKVCacheManager}
> class max.kv\_cache.paged\_kv\_cache.cache\_manager.PagedKVCacheManager(params, session, total\_num\_pages, total\_num\_host\_pages=0, enable\_runtime\_checks=False)
Paged KVCache manager with data and tensor parallelism support.
```python
kv_manager.claim(ctx1.request_id, replica_idx=0)
kv_manager.claim(ctx2.request_id, replica_idx=1)

# Allocate blocks for these requests
kv_manager.alloc(ctx1, replica_idx=0, num_steps=10)
kv_manager.alloc(ctx2, replica_idx=1, num_steps=10)

# Get KV cache inputs to feed to the graph.
# The outer list is indexed by replica: ctx1 is on replica 0, ctx2 on replica 1.
kv_cache_inputs = kv_manager.get_runtime_inputs(
    [[ctx1], [ctx2]], num_steps=10
)

# Run model...

# Update requests with newly generated tokens
ctx1.update(42)
ctx2.update(42)

# Commit newly written blocks to the prefix cache
kv_manager.step([[ctx1], [ctx2]])

# Release metadata and KV blocks for these requests
kv_manager.release(ctx1.request_id, replica_idx=0)
kv_manager.release(ctx2.request_id, replica_idx=1)
```
### `alloc()` {#max.kv_cache.paged_kv_cache.cache_manager.PagedKVCacheManager.alloc}
> alloc(data, replica\_idx, num\_steps=1)
Allocates blocks for a request to run for N steps.
This method allocates blocks needed by a request to run for N steps.
When prefix caching is enabled, some of the allocated blocks may be
retrieved from the prefix cache.
**Parameters:**
* data ([TextGenerationContext](../../interfaces.md#max.interfaces.TextGenerationContext)) – The text generation context for the request. The request ID
must already be assigned to a replica via claim.
* num\_steps ([int](https://docs.python.org/3/library/functions.html#int)) – The number of steps to reserve blocks for. Default: 1.
* replica\_idx ([int](https://docs.python.org/3/library/functions.html#int))
**Raises:**
InsufficientBlocksError – If there are insufficient free blocks to satisfy the allocation.
**Return type:**
None
### `claim()` {#max.kv_cache.paged_kv_cache.cache_manager.PagedKVCacheManager.claim}
> claim(request\_id, replica\_idx)
Reserve a sequence ID for the given request ID.
### `get_pct_used_blocks_after_allocation()` {#max.kv_cache.paged_kv_cache.cache_manager.PagedKVCacheManager.get_pct_used_blocks_after_allocation}
> get\_pct\_used\_blocks\_after\_allocation(ctx, replica\_idx, num\_steps=1)
Get the percentage of blocks used after allocating for a request.
**Parameters:**
* ctx ([TextGenerationContext](../../interfaces.md#max.interfaces.TextGenerationContext)) – The request context containing sequence information and token indices.
* num\_steps ([int](https://docs.python.org/3/library/functions.html#int)) – Number of additional steps to allocate blocks for. Defaults to 1.
* replica\_idx ([int](https://docs.python.org/3/library/functions.html#int))
**Returns:**
The percentage of total blocks used after allocating for the request.
### `get_runtime_inputs()` {#max.kv_cache.paged_kv_cache.cache_manager.PagedKVCacheManager.get_runtime_inputs}
> get\_runtime\_inputs(batches, num\_steps=1)
Get the graph inputs for per-replica batches of requests.
This method will raise a RuntimeError if any request has insufficient blocks
already allocated to it to run for the given number of steps.
**Parameters:**
* batches ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[TextGenerationContext](../../interfaces.md#max.interfaces.TextGenerationContext)]]) – Per-replica batches of requests
* num\_steps ([int](https://docs.python.org/3/library/functions.html#int)) – Number of steps to run for
### `step()` {#max.kv_cache.paged_kv_cache.cache_manager.PagedKVCacheManager.step}
> step(batches)
Commit new tokens into the prefix cache for per-replica batches.
---
## paged_kv_cache
Paged attention KV cache implementation with support for distributed inference.
This package provides the core implementation of paged KV cache management,
including cache managers, transfer engines for distributed settings, and tensor
parallelism support.
## Modules
* [`cache_manager`](/max/api/python/kv_cache/paged_kv_cache/cache_manager): Core paged KV cache manager implementation.
* [`tp_cache_manager`](/max/api/python/kv_cache/paged_kv_cache/tp_cache_manager): Tensor parallelism cache manager and input symbols.
* [`transfer_engine`](/max/api/python/kv_cache/paged_kv_cache/transfer_engine): KV cache transfer engine for distributed inference.
## Classes
* [`PagedKVCacheManager`](/max/api/python/kv_cache/paged_kv_cache/cache_manager): Manager for paged KV cache with data and tensor parallelism support.
* [`KVTransferEngine`](/max/api/python/kv_cache/paged_kv_cache/transfer_engine): Manages KV cache transfers between devices.
* [`KVTransferEngineMetadata`](/max/api/python/kv_cache/paged_kv_cache/transfer_engine): Metadata for transfer engine configuration.
* [`TransferReqData`](/max/api/python/kv_cache/paged_kv_cache/transfer_engine): Transfer request data structure.
---
## tp_cache_manager
PagedAttention-enabled KV cache for the Transformer leveraging the mo.opaque pattern.
---
## transfer_engine
KVCache Transfer Engine
## `KVTransferEngine` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine}
> class max.kv\_cache.paged\_kv\_cache.transfer\_engine.KVTransferEngine(name, tensors, \*, total\_num\_pages)
KVCache Transfer Engine with support for Data Parallelism (DP) and Tensor Parallelism (TP).
The engine accepts a 2D list of tensors: list\[list\[Buffer]] where the outer list
represents DP replicas and the inner list represents TP shards within each replica.
The TransferEngine communicates with other TransferEngines in other threads
or processes. However, individual TransferEngines themselves are not
thread-safe. It is intended to be used by MAX’s single-threaded scheduler.
**Parameters:**
* name ([str](https://docs.python.org/3/library/stdtypes.html#str))
* tensors (Sequence\[Sequence\[[Buffer](../../driver.md#max.driver.Buffer)]])
* total\_num\_pages ([int](https://docs.python.org/3/library/functions.html#int))
### `bytes_per_page` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine.bytes_per_page}
> bytes\_per\_page: [int](https://docs.python.org/3/library/functions.html#int)
Bytes per page for each tensor.
### `cleanup()` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine.cleanup}
> cleanup()
Release all resources associated with the transfer engine.
Should be called before the transfer engine is garbage collected.
Moving this logic into the \_\_del\_\_ destructor causes a UCX error for unknown reasons.
**Return type:**
None
### `cleanup_transfer()` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine.cleanup_transfer}
> cleanup\_transfer(transfer\_req)
Cleanup a transfer. This should be called after a transfer is complete.
**Parameters:**
transfer\_req ([TransferReqData](#max.kv_cache.paged_kv_cache.transfer_engine.TransferReqData)) – The transfer request to cleanup.
**Return type:**
None
### `completed_recv_transfers` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine.completed_recv_transfers}
> completed\_recv\_transfers: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [int](https://docs.python.org/3/library/functions.html#int)]]
Map of agent names to completed recv transfers.
### `connect()` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine.connect}
> connect(remote)
Connect to a remote engine (all replicas).
**Parameters:**
remote ([KVTransferEngineMetadata](#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngineMetadata)) – Metadata for the remote engine (all replicas).
**Return type:**
None
### `dp` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine.dp}
> dp: [int](https://docs.python.org/3/library/functions.html#int)
Number of DP replicas.
### `inflight_send_transfers` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine.inflight_send_transfers}
> inflight\_send\_transfers: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [TransferReqData](#max.kv_cache.paged_kv_cache.transfer_engine.TransferReqData)]
Map of transfer names to send transfer request data.
### `initiate_send_transfer()` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine.initiate_send_transfer}
> initiate\_send\_transfer(remote\_metadata, src\_idxs, dst\_idxs, src\_replica\_idx, dst\_replica\_idx)
Initiate a transfer from current engine to remote engine.
The same page indices are broadcast to all TP shards within the source and destination replicas.
**Parameters:**
* remote\_metadata ([KVTransferEngineMetadata](#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngineMetadata)) – Metadata for the remote engine.
* src\_idxs ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]) – List of indices of the source pages in the current engine.
* dst\_idxs ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]) – List of indices of the destination pages in the remote engine.
* src\_replica\_idx ([int](https://docs.python.org/3/library/functions.html#int)) – Index of the source replica to transfer from.
* dst\_replica\_idx ([int](https://docs.python.org/3/library/functions.html#int)) – Index of the destination replica to transfer to.
### `is_complete()` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine.is_complete}
> is\_complete(transfer\_req)
Checks if a given send or recv transfer is completed.
:::caution Caution
This method is prone to infinite loops. For the transfer to progress,
the remote engine MUST call wait\_recv\_complete. As such, the following
code will hang:
```python
transfer_req = engine_1.write_to(...)
while not engine_1.is_complete(transfer_req):
    pass
while not engine_2.is_complete(transfer_req):
    pass
```
Instead do:
```python
transfer_req = engine_1.write_to(...)
while not engine_1.is_complete(transfer_req) or not engine_2.is_complete(transfer_req):
    pass
```
:::
**Parameters:**
transfer\_req ([TransferReqData](#max.kv_cache.paged_kv_cache.transfer_engine.TransferReqData)) – The transfer request.
**Returns:**
True if all transfers have completed; False otherwise.
### `memory_type` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine.memory_type}
> memory\_type: MemoryType
Type of memory being managed (e.g. DRAM).
### `metadata` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine.metadata}
> property metadata: [KVTransferEngineMetadata](#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngineMetadata)
Get metadata for all replicas.
**Returns:**
Metadata for the entire engine (all replicas).
### `name` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine.name}
> name: [str](https://docs.python.org/3/library/stdtypes.html#str)
Name of transfer engine / nixl agent.
### `remote_agent_to_engine` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine.remote_agent_to_engine}
> remote\_agent\_to\_engine: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [str](https://docs.python.org/3/library/stdtypes.html#str)]
Map of remote agent names to their engine names.
### `remote_connections` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine.remote_connections}
> remote\_connections: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [KVTransferEngineMetadata](#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngineMetadata)]
Map of remote engine names to their metadata.
### `sync_and_release()` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine.sync_and_release}
> sync\_and\_release(transfer\_req)
Wait for a transfer to complete and release the transfer after it completes.
### `total_num_pages` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine.total_num_pages}
> total\_num\_pages: [int](https://docs.python.org/3/library/functions.html#int)
Total number of pages in each tensor (same across all replicas).
### `tp` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngine.tp}
> tp: [int](https://docs.python.org/3/library/functions.html#int)
Number of TP shards per replica.
## `KVTransferEngineMetadata` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngineMetadata}
> class max.kv\_cache.paged\_kv\_cache.transfer\_engine.KVTransferEngineMetadata(\*, name, total\_num\_pages, bytes\_per\_page, memory\_type, hostname, agents\_meta)
Metadata associated with a transfer engine.
This is safe to send between threads/processes.
### `bytes_per_page` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngineMetadata.bytes_per_page}
> bytes\_per\_page: [int](https://docs.python.org/3/library/functions.html#int)
Bytes per page for each tensor.
### `hostname` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngineMetadata.hostname}
> hostname: [str](https://docs.python.org/3/library/stdtypes.html#str)
Hostname of the machine that the transfer engine is running on.
### `memory_type` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngineMetadata.memory_type}
> memory\_type: MemoryType
Memory type of the transfer engine.
### `name` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngineMetadata.name}
> name: [str](https://docs.python.org/3/library/stdtypes.html#str)
Base name of the transfer engine.
### `total_num_pages` {#max.kv_cache.paged_kv_cache.transfer_engine.KVTransferEngineMetadata.total_num_pages}
> total\_num\_pages: [int](https://docs.python.org/3/library/functions.html#int)
Total number of pages in each tensor.
## `TensorAgent` {#max.kv_cache.paged_kv_cache.transfer_engine.TensorAgent}
> class max.kv\_cache.paged\_kv\_cache.transfer\_engine.TensorAgent(agent, agent\_name, tensor, base\_addr, ucx\_backend, device\_id, agent\_metadata, reg\_dlist)
Manages a single tensor and its associated NIXL agent for transfers.
This class holds both the runtime state (live objects) and can generate
the serializable metadata for communication between engines.
### `device_id` {#max.kv_cache.paged_kv_cache.transfer_engine.TensorAgent.device_id}
> device\_id: [int](https://docs.python.org/3/library/functions.html#int)
Device ID for this tensor.
### `reg_dlist` {#max.kv_cache.paged_kv_cache.transfer_engine.TensorAgent.reg_dlist}
> reg\_dlist: RegistrationDescriptorList
Registration descriptor list for this tensor.
### `tensor` {#max.kv_cache.paged_kv_cache.transfer_engine.TensorAgent.tensor}
> tensor: [Buffer](../../driver.md#max.driver.Buffer)
Tensor for this agent.
### `to_metadata()` {#max.kv_cache.paged_kv_cache.transfer_engine.TensorAgent.to_metadata}
> to\_metadata()
Convert to serializable metadata for communication.
### `ucx_backend` {#max.kv_cache.paged_kv_cache.transfer_engine.TensorAgent.ucx_backend}
> ucx\_backend: [int](https://docs.python.org/3/library/functions.html#int)
UCX backend for this tensor.
## `TensorAgentMetadata` {#max.kv_cache.paged_kv_cache.transfer_engine.TensorAgentMetadata}
> class max.kv\_cache.paged\_kv\_cache.transfer\_engine.TensorAgentMetadata(\*, agent\_name, metadata, base\_addr, device\_id)
Metadata for a single tensor/agent in the transfer engine.
This is used for serialization and communication between engines.
### `agent_name` {#max.kv_cache.paged_kv_cache.transfer_engine.TensorAgentMetadata.agent_name}
> agent\_name: [str](https://docs.python.org/3/library/stdtypes.html#str)
Name of this agent.
### `base_addr` {#max.kv_cache.paged_kv_cache.transfer_engine.TensorAgentMetadata.base_addr}
> base\_addr: [int](https://docs.python.org/3/library/functions.html#int)
Base memory address for this tensor.
### `device_id` {#max.kv_cache.paged_kv_cache.transfer_engine.TensorAgentMetadata.device_id}
> device\_id: [int](https://docs.python.org/3/library/functions.html#int)
Device ID for this tensor.
### `metadata` {#max.kv_cache.paged_kv_cache.transfer_engine.TensorAgentMetadata.metadata}
> metadata: [bytes](https://docs.python.org/3/library/stdtypes.html#bytes)
Metadata for this agent.
## `TransferReqData` {#max.kv_cache.paged_kv_cache.transfer_engine.TransferReqData}
> class max.kv\_cache.paged\_kv\_cache.transfer\_engine.TransferReqData(\*, dst\_name, src\_name, transfer\_name, transfer\_ids, src\_idxs, dst\_idxs, src\_replica\_idx, dst\_replica\_idx)
Metadata associated with a transfer request.
This is safe to send between threads/processes.
### `dst_idxs` {#max.kv_cache.paged_kv_cache.transfer_engine.TransferReqData.dst_idxs}
> dst\_idxs: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]
Length of destination indices can differ from len(transfer\_ids).
### `dst_name` {#max.kv_cache.paged_kv_cache.transfer_engine.TransferReqData.dst_name}
> dst\_name: [str](https://docs.python.org/3/library/stdtypes.html#str)
Base name of destination engine.
### `dst_replica_idx` {#max.kv_cache.paged_kv_cache.transfer_engine.TransferReqData.dst_replica_idx}
> dst\_replica\_idx: [int](https://docs.python.org/3/library/functions.html#int)
Index of the destination replica this transfer is to.
### `src_idxs` {#max.kv_cache.paged_kv_cache.transfer_engine.TransferReqData.src_idxs}
> src\_idxs: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]
Length of source indices can differ from len(transfer\_ids).
### `src_name` {#max.kv_cache.paged_kv_cache.transfer_engine.TransferReqData.src_name}
> src\_name: [str](https://docs.python.org/3/library/stdtypes.html#str)
Base name of source engine.
### `src_replica_idx` {#max.kv_cache.paged_kv_cache.transfer_engine.TransferReqData.src_replica_idx}
> src\_replica\_idx: [int](https://docs.python.org/3/library/functions.html#int)
Index of the source replica this transfer is from.
### `transfer_ids` {#max.kv_cache.paged_kv_cache.transfer_engine.TransferReqData.transfer_ids}
> transfer\_ids: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]
Transfer IDs (one per TP shard in the replica).
### `transfer_name` {#max.kv_cache.paged_kv_cache.transfer_engine.TransferReqData.transfer_name}
> transfer\_name: [str](https://docs.python.org/3/library/stdtypes.html#str)
Transfer name.
## `available_port()` {#max.kv_cache.paged_kv_cache.transfer_engine.available_port}
> max.kv\_cache.paged\_kv\_cache.transfer\_engine.available\_port(start\_port=8000, end\_port=9000, max\_attempts=100)
Find an available TCP port in the given range.
**Parameters:**
* start\_port ([int](https://docs.python.org/3/library/functions.html#int)) – The lower bound of the port range (inclusive).
* end\_port ([int](https://docs.python.org/3/library/functions.html#int)) – The upper bound of the port range (inclusive).
* max\_attempts ([int](https://docs.python.org/3/library/functions.html#int)) – Maximum number of attempts to find a free port.
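The conventional way to probe for a free port is to attempt a bind and catch the failure. A minimal sketch under that assumption (the actual `available_port` implementation may differ in its probing strategy):

```python
import random
import socket

def available_port(start_port=8000, end_port=9000, max_attempts=100):
    """Probe random ports in [start_port, end_port] until one binds."""
    for _ in range(max_attempts):
        port = random.randint(start_port, end_port)
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.bind(("127.0.0.1", port))  # bind succeeds only if free
            except OSError:
                continue  # port in use; try another
            return port
    raise RuntimeError("no available port found in range")

port = available_port()
assert 8000 <= port <= 9000
```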
## `load_kv_manager()` {#max.kv_cache.registry.load_kv_manager}
> max.kv\_cache.registry.load\_kv\_manager(params, max\_batch\_size, max\_seq\_len, session, available\_cache\_memory)
Loads a single KV cache manager from the given params.
---
## Embedding (Nn)
## `Embedding` {#max.nn.Embedding}
> class max.nn.Embedding(vocab\_size, \*, dim=None, dims=None)
A vector embedding.
An embedding can be thought of as a lookup table for vectors by index.
Given an input tensor of indices into the embedding, the result
of the embedding lookup is a tensor of the same shape, but with each index
replaced by the value of the vector in that location in the embedding table.
The common case for embeddings is a 1-dimensional embedding:
```python
from max.dtype import DType
from max.tensor import Tensor
from max.nn import Embedding
embedding = Embedding(vocab_size=1000, dim=128)
tokens = Tensor.ones([10], dtype=DType.uint64)
embedded = embedding(tokens)
assert embedded.shape == [10, 128]
```
However they just as easily support multi-dimensional embeddings:
```python
from max.dtype import DType
from max.tensor import Tensor
from max.nn import Embedding
embedding = Embedding(vocab_size=1000, dims=[16, 128])
tokens = Tensor.ones([10], dtype=DType.uint64)
embedded = embedding(tokens)
assert embedded.shape == [10, 16, 128]
```
### `dim` {#max.nn.Embedding.dim}
> property dim: [Dim](../graph/dim.md#max.graph.dim.Dim)
The dimension of the vectors in the embedding (for a 1d embedding).
Raises: An error if the embedding is not 1-dimensional (that is, for 0- or >1-dimensional embeddings).
### `dims` {#max.nn.Embedding.dims}
> property dims: [Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[Dim](../graph/dim.md#max.graph.dim.Dim)]
The dimensions of the vectors in the embedding.
### `forward()` {#max.nn.Embedding.forward}
> forward(indices)
Applies the vector embedding to the input tensor of indices.
**Parameters:**
indices ([Tensor](../tensor.md#max.tensor.Tensor)) – An integer-valued tensor. Values must be in the range
\[0, vocab\_size) for the embedding.
**Returns:**
A dense tensor made by looking up each index in the vector embedding.
For an input of shape `(*batch, indices)` and an embedding of shape
`(vocab_size, *dims)`, the result will have shape `(*batch, indices, *dims)`.
**Return type:**
[Tensor](../tensor.md#max.tensor.Tensor)
### `vocab_size` {#max.nn.Embedding.vocab_size}
> property vocab\_size: [Dim](../graph/dim.md#max.graph.dim.Dim)
The vocab size of the embedding.
Indices outside the range of \[0, vocab\_size) are illegal.
### `weight` {#max.nn.Embedding.weight}
> weight: [Tensor](../tensor.md#max.tensor.Tensor)
:::note Note
For the legacy graph-based embedding layer, see [legacy/embedding](/max/api/python/nn/legacy/embedding).
:::
---
## Linear
## `Linear` {#max.nn.Linear}
> class max.nn.Linear(in\_dim, out\_dim, \*, bias=True)
A unary linear transformation over an input tensor.
Linear is defined as f(x) = x @ W\.T + B where W is the
weight tensor and B is an optional bias tensor.
If W is not square then the transformation represents a
dimensionality change. By convention the weight tensor is stored
transposed.
```python
from max.nn import Linear
from max.tensor import Tensor
model = Linear(5, 10)
assert dict(model.parameters) == {
"weight": model.weight, "bias": model.bias
}
result = model(Tensor.ones([5]))
assert result.shape == [10]
```
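The stored-transposed convention means the weight has shape `(out_dim, in_dim)`, so each output element is a dot product of the input with one row of the weight. A plain-Python sketch of `f(x) = x @ W.T + B` for a 1-D input (an illustration with no MAX dependency):

```python
def linear(x, weight, bias):
    """Compute f(x) = x @ W.T + B for a 1-D input x.

    `weight` is stored transposed with shape (out_dim, in_dim), so each
    output element is the dot product of x with one weight row plus bias.
    """
    return [
        sum(xi * wi for xi, wi in zip(x, row)) + b
        for row, b in zip(weight, bias)
    ]

# Mirrors Linear(5, 10): in_dim=5, out_dim=10, so the output has 10 elements.
weight = [[0.0] * 5 for _ in range(10)]
bias = [1.0] * 10
assert len(linear([1.0] * 5, weight, bias)) == 10
```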
### `bias` {#max.nn.Linear.bias}
> bias: [Tensor](../tensor.md#max.tensor.Tensor) | [Literal](https://docs.python.org/3/library/typing.html#typing.Literal)\[0]
The bias `Tensor` for the linear transformation (or 0 if bias is disabled).
### `forward()` {#max.nn.Linear.forward}
> forward(x)
Applies a linear transformation to the input tensor.
Linear is defined as f(x) = x @ W\.T + B where W is the
weight tensor and B is an optional bias tensor.
**Parameters:**
x ([Tensor](../tensor.md#max.tensor.Tensor)) – The input tensor
**Returns:**
The result of applying the linear transformation to the tensor.
**Return type:**
[Tensor](../tensor.md#max.tensor.Tensor)
### `in_dim` {#max.nn.Linear.in_dim}
> property in\_dim: [Dim](../graph/dim.md#max.graph.dim.Dim)
The input dimension for the transformation.
### `out_dim` {#max.nn.Linear.out_dim}
> property out\_dim: [Dim](../graph/dim.md#max.graph.dim.Dim)
The output dimension for the transformation.
### `weight` {#max.nn.Linear.weight}
> weight: [Tensor](../tensor.md#max.tensor.Tensor)
The weight `Tensor` for the linear transformation.
:::note Note
For the legacy graph-based linear layer, see [legacy/linear](/max/api/python/nn/legacy/linear).
:::
---
## nn
APIs to build neural network components for deep learning models with Python.
The MAX neural network API provides two namespaces:
* **max.nn**: Eager-style execution.
* **max.nn.legacy**: Legacy graph-based API (for backward compatibility).
For functional operations like relu, softmax, and more, see the
[`functional`](/max/api/python/functional) module.
## Core API
Use these modules for all models. They provide eager-style execution with
PyTorch-style syntax.
* [`Embedding`](/max/api/python/nn/Embedding): Vector embedding layer for token representation.
* [`Linear`](/max/api/python/nn/Linear): Linear transformation layer with weights and bias.
* [`module`](/max/api/python/nn/module): Base class for all neural network modules.
* [`norm`](/max/api/python/nn/norm): Normalization layers for training stability.
* [`rope`](/max/api/python/nn/rope): Rotary position embeddings for sequence models.
* [`sequential`](/max/api/python/nn/sequential): Containers for composing modules sequentially.
## Legacy API
:::note Note
The legacy API remains available for backward compatibility. For all new models,
use the max.nn API.
:::
The legacy API provides graph-based layer implementations. See the full
reference:
* [`legacy`](/max/api/python/nn/legacy): Neural network legacy API documentation.
---
## attention_with_rope
An attention mechanism with RoPE, optimized for the opaque KV cache.
## `AttentionWithRope` {#max.nn.legacy.attention.attention_with_rope.AttentionWithRope}
> class max.nn.legacy.attention.attention\_with\_rope.AttentionWithRope(\*, rope, sharding\_strategy=None, num\_attention\_heads, num\_key\_value\_heads, hidden\_size, kv\_params, devices=None, dtype=float32, linear\_cls=Linear, stacked\_qkv=False, scale=None, has\_bias=False, float8\_config=None, clip\_qkv=None, use\_qk\_norm=False, rms\_norm\_eps=1e-06)
Implementation of attention that uses Rotary Position Embedding (RoPE).
### `qkv_input_scale` {#max.nn.legacy.attention.attention_with_rope.AttentionWithRope.qkv_input_scale}
> property qkv\_input\_scale: [TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue) | [None](https://docs.python.org/3/library/constants.html#None)
The max of q, k, and v scale input vectors.
### `qkv_weight_scale` {#max.nn.legacy.attention.attention_with_rope.AttentionWithRope.qkv_weight_scale}
> property qkv\_weight\_scale: [TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue)
The max of q, k, and v scale weight vectors.
### `qkv_weight_scale_2` {#max.nn.legacy.attention.attention_with_rope.AttentionWithRope.qkv_weight_scale_2}
> property qkv\_weight\_scale\_2: [TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue) | [None](https://docs.python.org/3/library/constants.html#None)
The max of the secondary q, k, and v scale weight vectors.
### `rope` {#max.nn.legacy.attention.attention_with_rope.AttentionWithRope.rope}
> rope: [RotaryEmbedding](../rotary_embedding.md#max.nn.legacy.rotary_embedding.RotaryEmbedding)
### `shard()` {#max.nn.legacy.attention.attention_with_rope.AttentionWithRope.shard}
> shard(devices)
Create sharded views across devices (tensor-parallel).
Returns one AttentionWithRope per device with appropriately sliced weights.
### `sharding_strategy` {#max.nn.legacy.attention.attention_with_rope.AttentionWithRope.sharding_strategy}
> property sharding\_strategy: ShardingStrategy | [None](https://docs.python.org/3/library/constants.html#None)
Get the Module sharding strategy.
### `wqkv` {#max.nn.legacy.attention.attention_with_rope.AttentionWithRope.wqkv}
> property wqkv: [TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue)
The concatenation of q, k, and v weight vectors.
### `wqkv_bias` {#max.nn.legacy.attention.attention_with_rope.AttentionWithRope.wqkv_bias}
> property wqkv\_bias: [TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue) | [None](https://docs.python.org/3/library/constants.html#None)
The concatenation of q, k, and v bias weight vectors.
## `AttentionWithRopeNoOpaque` {#max.nn.legacy.attention.attention_with_rope.AttentionWithRopeNoOpaque}
> class max.nn.legacy.attention.attention\_with\_rope.AttentionWithRopeNoOpaque(\*, rope, num\_attention\_heads, num\_key\_value\_heads, hidden\_size, kv\_params, devices=None, dtype=float32, linear\_cls=Linear, scale=None)
Attention with RoPE without opaque KV cache.
Assumes:
* no float8 (no float8\_config)
* no stacked qkv
* no bias
* no clip\_qkv
### `rope` {#max.nn.legacy.attention.attention_with_rope.AttentionWithRopeNoOpaque.rope}
> rope: [RotaryEmbedding](../rotary_embedding.md#max.nn.legacy.rotary_embedding.RotaryEmbedding)
## `DataParallelAttentionWithRope` {#max.nn.legacy.attention.attention_with_rope.DataParallelAttentionWithRope}
> class max.nn.legacy.attention.attention\_with\_rope.DataParallelAttentionWithRope(\*, rope, num\_attention\_heads, num\_key\_value\_heads, hidden\_size, kv\_params, devices=None, dtype=float32, linear\_cls=Linear, stacked\_qkv=False, scale=None, has\_bias=False, float8\_config=None, clip\_qkv=None, use\_qk\_norm=False, rms\_norm\_eps=1e-06)
Data-parallel implementation of Attention with RoPE.
This replicates the attention module across devices and runs each replica on
its local inputs (x, kv, freqs\_cis, input\_row\_offsets). No collective ops
are required; KV-cache remains local to each device.
**Notes:**
* Assumes the caller has already distributed xs, kv\_collections,
freqs\_cis, and input\_row\_offsets so that index i corresponds to
device i, with input\_row\_offsets\[i] rebased to start at 0.
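The rebasing requirement in the note above can be illustrated in plain Python. This is a hypothetical helper (not part of the API) that splits a global ragged batch's row offsets across devices and shifts each slice to start at 0:

```python
def split_row_offsets(row_offsets, num_devices):
    """Split ragged-batch row offsets across devices, rebasing each to 0.

    `row_offsets` has one entry per sequence plus a final total; e.g.
    [0, 3, 7, 12, 14] describes four sequences of lengths 3, 4, 5, and 2.
    Assumes the sequence count divides evenly across devices.
    """
    num_seqs = len(row_offsets) - 1
    per_device = num_seqs // num_devices
    shards = []
    for d in range(num_devices):
        lo, hi = d * per_device, (d + 1) * per_device
        base = row_offsets[lo]  # subtracting this rebases the shard to 0
        shards.append([off - base for off in row_offsets[lo:hi + 1]])
    return shards

# Two devices, two sequences each: every shard's offsets start at 0.
assert split_row_offsets([0, 3, 7, 12, 14], 2) == [[0, 3, 7], [0, 5, 7]]
```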
## `GGUFQAttentionWithRope` {#max.nn.legacy.attention.attention_with_rope.GGUFQAttentionWithRope}
### `rope` {#max.nn.legacy.attention.attention_with_rope.GGUFQAttentionWithRope.rope}
> rope: [RotaryEmbedding](../rotary_embedding.md#max.nn.legacy.rotary_embedding.RotaryEmbedding)
### `wqkv` {#max.nn.legacy.attention.attention_with_rope.GGUFQAttentionWithRope.wqkv}
> property wqkv: [TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue)
The concatenation of q, k, and v weight vectors.
### `wqkv_bias` {#max.nn.legacy.attention.attention_with_rope.GGUFQAttentionWithRope.wqkv_bias}
> property wqkv\_bias: [TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue) | [None](https://docs.python.org/3/library/constants.html#None)
The concatenation of q, k, and v bias weight vectors.
## `GPTQAttentionWithRope` {#max.nn.legacy.attention.attention_with_rope.GPTQAttentionWithRope}
> class max.nn.legacy.attention.attention\_with\_rope.GPTQAttentionWithRope(quantization\_config, rope, num\_attention\_heads, num\_key\_value\_heads, hidden\_size, kv\_params, devices=None, dtype=float32, scale=None, linear\_cls=Linear)
Implementation of the GPTQ attention layer.
### `wqkv` {#max.nn.legacy.attention.attention_with_rope.GPTQAttentionWithRope.wqkv}
> property wqkv: [TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue)
The concatenation of q, k, and v weight vectors (packed + scales).
## `TensorParallelAttentionWithRope` {#max.nn.legacy.attention.attention_with_rope.TensorParallelAttentionWithRope}
> class max.nn.legacy.attention.attention\_with\_rope.TensorParallelAttentionWithRope(\*, rope, num\_attention\_heads, num\_key\_value\_heads, hidden\_size, kv\_params, devices=None, dtype=float32, linear\_cls=Linear, stacked\_qkv=False, scale=None, has\_bias=False, float8\_config=None, clip\_qkv=None, use\_qk\_norm=False, rms\_norm\_eps=1e-06)
Tensor-parallel wrapper that delegates sharding to the base module.
### `attention_mask_variant` {#max.nn.legacy.attention.mask_config.MHAMaskConfig.attention_mask_variant}
> attention\_mask\_variant: [AttentionMaskVariant](#max.nn.legacy.attention.mask_config.AttentionMaskVariant)
### `positional_encoding_variant` {#max.nn.legacy.attention.mask_config.MHAMaskConfig.positional_encoding_variant}
> positional\_encoding\_variant: [PositionalEncodingVariant](#max.nn.legacy.attention.mask_config.PositionalEncodingVariant)
## `MHAMaskVariant` {#max.nn.legacy.attention.mask_config.MHAMaskVariant}
> class max.nn.legacy.attention.mask\_config.MHAMaskVariant(value, names=\<not given>, \*values, module=None, qualname=None, type=None, start=1, boundary=None)
### `CAUSAL_ALIBI_MASK` {#max.nn.legacy.attention.mask_config.MHAMaskVariant.CAUSAL_ALIBI_MASK}
> CAUSAL\_ALIBI\_MASK = '1'
### `CAUSAL_MASK` {#max.nn.legacy.attention.mask_config.MHAMaskVariant.CAUSAL_MASK}
> CAUSAL\_MASK = '0'
### `CHUNKED_CAUSAL_MASK` {#max.nn.legacy.attention.mask_config.MHAMaskVariant.CHUNKED_CAUSAL_MASK}
> CHUNKED\_CAUSAL\_MASK = '3'
### `NULL_MASK` {#max.nn.legacy.attention.mask_config.MHAMaskVariant.NULL_MASK}
> NULL\_MASK = '2'
### `SLIDING_WINDOW_CAUSAL_MASK` {#max.nn.legacy.attention.mask_config.MHAMaskVariant.SLIDING_WINDOW_CAUSAL_MASK}
> SLIDING\_WINDOW\_CAUSAL\_MASK = '4'
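Two of these variants correspond to standard boolean mask shapes. A plain-Python sketch of the causal and sliding-window-causal patterns (`True` = may attend), written as an illustration rather than the kernel's actual representation:

```python
def causal_mask(n):
    """Token i may attend to tokens j <= i (no looking ahead)."""
    return [[j <= i for j in range(n)] for i in range(n)]

def sliding_window_causal_mask(n, window):
    """Causal, but limited to the most recent `window` tokens."""
    return [[i - window < j <= i for j in range(n)] for i in range(n)]

assert causal_mask(3) == [
    [True, False, False],
    [True, True, False],
    [True, True, True],
]
assert sliding_window_causal_mask(3, 2) == [
    [True, False, False],
    [True, True, False],
    [False, True, True],  # token 2 can no longer see token 0
]
```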
## `PositionalEncodingVariant` {#max.nn.legacy.attention.mask_config.PositionalEncodingVariant}
> class max.nn.legacy.attention.mask\_config.PositionalEncodingVariant(value, names=\<not given>, \*values, module=None, qualname=None, type=None, start=1, boundary=None)
### `ALIBI_POS` {#max.nn.legacy.attention.mask_config.PositionalEncodingVariant.ALIBI_POS}
> ALIBI\_POS = 'alibi\_pos'
### `NO_POS` {#max.nn.legacy.attention.mask_config.PositionalEncodingVariant.NO_POS}
> NO\_POS = 'no\_pos'
---
## multi_latent_attention
An attention mechanism with RoPE, optimized for the opaque KV cache.
## `DataParallelLatentAttentionWithRope` {#max.nn.legacy.attention.multi_latent_attention.DataParallelLatentAttentionWithRope}
> class max.nn.legacy.attention.multi\_latent\_attention.DataParallelLatentAttentionWithRope(\*\*kwargs)
Data-parallel implementation of Latent Attention with RoPE.
This replicates the attention module across devices and runs each replica on
its local inputs (x, kv, freqs\_cis, input\_row\_offsets). No collective ops
are required; KV-cache remains local to each device.
**Notes:**
* signal\_buffers is accepted for interface parity with the distributed
implementation but is not used here.
* Assumes the caller has already distributed xs, kv\_collections,
freqs\_cis, and input\_row\_offsets so that index i corresponds to
device i, with input\_row\_offsets\[i] rebased to start at 0.
### `create_mla_prefill_metadata()` {#max.nn.legacy.attention.multi_latent_attention.DataParallelLatentAttentionWithRope.create_mla_prefill_metadata}
> create\_mla\_prefill\_metadata(input\_row\_offsets\_, kv\_collections)
## `LatentAttentionWithRope` {#max.nn.legacy.attention.multi_latent_attention.LatentAttentionWithRope}
> class max.nn.legacy.attention.multi\_latent\_attention.LatentAttentionWithRope(\*, rope, num\_attention\_heads, num\_key\_value\_heads, hidden\_size, kv\_params, dtype, devices=None, linear\_cls=Linear, o\_proj\_dtype=None, o\_proj\_float8\_config=None, scale=None, q\_lora\_rank=None, kv\_lora\_rank=512, qk\_nope\_head\_dim=128, qk\_rope\_head\_dim=64, v\_head\_dim=128, buffer\_size=16384, graph\_mode=None)
Implementation of Latent Attention with Rope.
**Parameters:**
* rope ([RotaryEmbedding](../rotary_embedding.md#max.nn.legacy.rotary_embedding.RotaryEmbedding)) – The rope layer to borrow the freqs\_cis value from.
* num\_attention\_heads ([int](https://docs.python.org/3/library/functions.html#int)) – The number of attention heads.
* num\_key\_value\_heads ([int](https://docs.python.org/3/library/functions.html#int)) – Number of key/value heads.
* hidden\_size ([int](https://docs.python.org/3/library/functions.html#int)) – The dimension of the hidden states.
* kv\_params ([KVCacheParams](../kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) – KV Cache Params, including the number of kv heads, the
head dim, and data type.
* dtype ([DType](../../../dtype.md#max.dtype.DType)) – DType of the weights, currently only bfloat16 is supported.
* devices ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[DeviceRef](../../../graph/ops.md#max.graph.ops.DeviceRef)] | None) – Device to place the weights and run the computation. If
multiple are provided, the first device is used.
* linear\_cls ([Callable](../../../graph/ops.md#max.graph.ops.Callable)\[..., [Linear](../../Linear.md#max.nn.Linear)]) – Linear class to use for the outputs dense layer.
* o\_proj\_dtype ([DType](../../../dtype.md#max.dtype.DType) | None) – Optional dtype override for the output projection.
* o\_proj\_float8\_config ([Float8Config](../float8_config.md#max.nn.legacy.float8_config.Float8Config) | None) – Optional float8 config for the output projection.
* scale ([float](https://docs.python.org/3/library/functions.html#float) | None) – Value used to scale the results of the attention output.
* q\_lora\_rank ([int](https://docs.python.org/3/library/functions.html#int) | None) – Optional LoRA rank for Q projection.
* kv\_lora\_rank ([int](https://docs.python.org/3/library/functions.html#int)) – LoRA rank for KV projections.
* qk\_nope\_head\_dim ([int](https://docs.python.org/3/library/functions.html#int)) – Head dimension for non-positional encoding part.
* qk\_rope\_head\_dim ([int](https://docs.python.org/3/library/functions.html#int)) – Head dimension for rope part.
* v\_head\_dim ([int](https://docs.python.org/3/library/functions.html#int)) – Head dimension for value.
* buffer\_size ([int](https://docs.python.org/3/library/functions.html#int)) – Buffer size for storing temporary results during
prefill, in units of tokens.
* graph\_mode ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) – Pipeline role to use for the attention layer. Should be
“prefill”, “decode”, or “auto”.
### `rope` {#max.nn.legacy.attention.multi_latent_attention.LatentAttentionWithRope.rope}
> rope: [RotaryEmbedding](../rotary_embedding.md#max.nn.legacy.rotary_embedding.RotaryEmbedding)
### `shard()` {#max.nn.legacy.attention.multi_latent_attention.LatentAttentionWithRope.shard}
> shard(devices)
Creates sharded views of this Module across multiple devices.
**Parameters:**
devices ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[DeviceRef](../../../graph/type.md#max.graph.type.DeviceRef)]) – Iterable of devices to place the shards on.
**Returns:**
List of sharded LatentAttentionWithRope instances, one for each device.
### `sharding_strategy` {#max.nn.legacy.attention.multi_latent_attention.LatentAttentionWithRope.sharding_strategy}
> property sharding\_strategy: ShardingStrategy | [None](https://docs.python.org/3/library/constants.html#None)
Get the Module sharding strategy.
### `w_uk_uv` {#max.nn.legacy.attention.multi_latent_attention.LatentAttentionWithRope.w_uk_uv}
> property w\_uk\_uv: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue)]
The UK and UV weight matrices of the latent attention decomposition.
## `MLAPrefillMetadata` {#max.nn.legacy.attention.multi_latent_attention.MLAPrefillMetadata}
> class max.nn.legacy.attention.multi\_latent\_attention.MLAPrefillMetadata(buffer\_row\_offsets, cache\_offsets, buffer\_lengths)
Dataclass to hold MLA prefill metadata.
### `buffer_lengths` {#max.nn.legacy.attention.multi_latent_attention.MLAPrefillMetadata.buffer_lengths}
> buffer\_lengths: [TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue)
### `buffer_row_offsets` {#max.nn.legacy.attention.multi_latent_attention.MLAPrefillMetadata.buffer_row_offsets}
> buffer\_row\_offsets: [TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue)
### `cache_offsets` {#max.nn.legacy.attention.multi_latent_attention.MLAPrefillMetadata.cache_offsets}
> cache\_offsets: [TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue)
## `TensorParallelLatentAttentionWithRope` {#max.nn.legacy.attention.multi_latent_attention.TensorParallelLatentAttentionWithRope}
> class max.nn.legacy.attention.multi\_latent\_attention.TensorParallelLatentAttentionWithRope(\*\*kwargs)
Distributed tensor-parallel implementation of Latent Attention with
RoPE. Note that using tensor parallelism for MLA duplicates the KV-cache
across all devices, which is not efficient.
### `create_mla_prefill_metadata()` {#max.nn.legacy.attention.multi_latent_attention.TensorParallelLatentAttentionWithRope.create_mla_prefill_metadata}
> create\_mla\_prefill\_metadata(input\_row\_offsets\_, kv\_collections)
---
## multihead_attention
## `MultiheadAttention` {#max.nn.legacy.attention.multihead_attention.MultiheadAttention}
> class max.nn.legacy.attention.multihead\_attention.MultiheadAttention(num\_attention\_heads, hidden\_size, devices=None, dtype=float32, scale=None, qkv\_has\_bias=False, o\_proj\_has\_bias=False, stacked\_qkv=False)
Multihead attention that handles both single and distributed computation.
### `wqkv` {#max.nn.legacy.attention.multihead_attention.MultiheadAttention.wqkv}
> property wqkv: [TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue)
The concatenation of q, k, and v weight vectors.
### `wqkv_bias` {#max.nn.legacy.attention.multihead_attention.MultiheadAttention.wqkv_bias}
> property wqkv\_bias: [TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue) | [None](https://docs.python.org/3/library/constants.html#None)
The concatenation of q, k, and v bias weight vectors.
---
## ragged_attention
A vanilla attention mechanism optimized for the opaque KV cache, with mask variants provided inside the kernel.
## `RaggedAttention` {#max.nn.legacy.attention.ragged_attention.RaggedAttention}
> class max.nn.legacy.attention.ragged\_attention.RaggedAttention(\*, mask\_variant, num\_attention\_heads, num\_key\_value\_heads, hidden\_size, kv\_params, devices=None, dtype=float32, linear\_cls=Linear, stacked\_qkv=False, scale=None, has\_bias=False, clip\_qkv=None)
Layer that computes the self attention score for ragged inputs.
### `wqkv` {#max.nn.legacy.attention.ragged_attention.RaggedAttention.wqkv}
> property wqkv: [TensorValue](../../../graph/TensorValue.md#max.graph.TensorValue)
The concatenation of q, k, and v weight vectors.
---
## clamp
## `clamp()` {#max.nn.legacy.clamp.clamp}
> max.nn.legacy.clamp.clamp(x, min=None, max=None)
Clamps values in `x` to `[min, max]`.
**Parameters:**
* x ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Input tensor to clamp.
* min ([float](https://docs.python.org/3/library/functions.html#float) | None) – Minimum value. If `None`, no lower bound is applied.
* max ([float](https://docs.python.org/3/library/functions.html#float) | None) – Maximum value. If `None`, no upper bound is applied.
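The elementwise semantics can be sketched with the standard library alone. This illustration renames the bounds `lo`/`hi` (the API uses `min`/`max`) so Python's builtins stay available:

```python
def clamp(x, lo=None, hi=None):
    """Clamp each element of x to [lo, hi]; None disables that bound."""
    def clamp_one(v):
        if lo is not None:
            v = max(v, lo)  # apply the lower bound
        if hi is not None:
            v = min(v, hi)  # apply the upper bound
        return v
    return [clamp_one(v) for v in x]

assert clamp([-2.0, 0.5, 3.0], lo=0.0, hi=1.0) == [0.0, 0.5, 1.0]
assert clamp([-2.0, 0.5, 3.0], hi=1.0) == [-2.0, 0.5, 1.0]  # no lower bound
```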
---
## comm
## `Allreduce` {#max.nn.legacy.comm.Allreduce}
> class max.nn.legacy.comm.Allreduce(num\_accelerators)
Layer to perform allreduce operation with automatic implementation selection.
Automatically chooses between peer-to-peer optimized allreduce and naive
device-to-device transfer based on accelerator connectivity.
**Parameters:**
num\_accelerators ([int](https://docs.python.org/3/library/functions.html#int)) – Number of accelerators participating in the allreduce operation.
### `devices` {#max.nn.legacy.comm.Allreduce.devices}
> devices: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[Accelerator](../../../driver.md#max.driver.Accelerator)]
List of accelerators involved in the allreduce operation.
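Functionally, allreduce leaves every participant with the elementwise sum of all inputs. A naive single-process sketch of that contract (hypothetical data, no device transfers):

```python
def naive_allreduce(per_device_tensors):
    """Sum corresponding elements across devices, then give every device
    a copy of the result (the 'reduce then broadcast' view of allreduce)."""
    reduced = [sum(vals) for vals in zip(*per_device_tensors)]
    return [list(reduced) for _ in per_device_tensors]

inputs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # three accelerators
assert naive_allreduce(inputs) == [[9.0, 12.0]] * 3
```

Real implementations (peer-to-peer or ring allreduce) avoid materializing the full sum on one device, but the result is the same.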
## `Signals` {#max.nn.legacy.comm.Signals}
> class max.nn.legacy.comm.Signals(devices)
Signal buffers used for peer-to-peer communication in allreduce.
Device code uses these buffers by enabling peer-to-peer access.
Then thread blocks use the buffers to implement barriers for
synchronization, and to hold intermediate communication results.
### `NUM_BYTES` {#max.nn.legacy.comm.Signals.NUM_BYTES}
> NUM\_BYTES = 537919488
The size of the signal buffers used for communication in allreduce.
### `buffers()` {#max.nn.legacy.comm.Signals.buffers}
> buffers()
Allocates and returns buffers used for communication in allreduce.
Synchronizes so that buffers are ready for use when this method
returns.
### `devices` {#max.nn.legacy.comm.Signals.devices}
> devices: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[DeviceRef](../../../graph/type.md#max.graph.type.DeviceRef)]
List of graph devices that these signals communicate between.
### `input_types()` {#max.nn.legacy.comm.Signals.input_types}
> input\_types()
Gets graph input types corresponding to these signal buffers.
---
## conv
## `Conv1D` {#max.nn.legacy.conv.Conv1D}
### `bias` {#max.nn.legacy.conv.Conv1D.bias}
> bias: [Weight](../../graph/Weight.md#max.graph.Weight) | [None](https://docs.python.org/3/library/constants.html#None) = None
The optional bias vector stored on CPU with shape (out\_channels,).
Model init moves the bias to [`device`](#max.nn.legacy.conv.Conv1D.device) if present.
### `device` {#max.nn.legacy.conv.Conv1D.device}
> device: [DeviceRef](../../graph/ops.md#max.graph.ops.DeviceRef) | [None](https://docs.python.org/3/library/constants.html#None)
The device where matrix operations are performed.
### `dilation` {#max.nn.legacy.conv.Conv1D.dilation}
> dilation: [int](https://docs.python.org/3/library/functions.html#int)
Controls the dilation rate.
### `filter` {#max.nn.legacy.conv.Conv1D.filter}
> filter: [Weight](../../graph/Weight.md#max.graph.Weight)
The weight matrix stored on CPU with shape (kernel\_size, in\_channels / num\_groups, out\_channels).
Model init moves the weight to [`device`](#max.nn.legacy.conv.Conv1D.device).
### `num_groups` {#max.nn.legacy.conv.Conv1D.num_groups}
> num\_groups: [int](https://docs.python.org/3/library/functions.html#int)
Number of blocked connections from input channels to output channels.
### `padding` {#max.nn.legacy.conv.Conv1D.padding}
> padding: [int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]
Controls the amount of padding applied to the input.
If int: symmetric padding applied to both sides (pad\_left = pad\_right = padding).
If tuple\[int, int]: asymmetric padding as (pad\_left, pad\_right).
### `permute` {#max.nn.legacy.conv.Conv1D.permute}
> permute: [bool](https://docs.python.org/3/library/functions.html#bool) = False
Whether `self.filter` is permuted from the PyTorch order to the MAX order.
PyTorch order: (out\_channels, in\_channels / num\_groups, kernel\_size).
MAX API order: (kernel\_size, in\_channels / num\_groups, out\_channels).
### `stride` {#max.nn.legacy.conv.Conv1D.stride}
> stride: [int](https://docs.python.org/3/library/functions.html#int)
Controls the stride for the cross-correlation.
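Padding, dilation, and stride together determine the output length via the standard cross-correlation formula. A small helper illustrating the arithmetic (not part of the API), using the same int-or-tuple padding convention as `Conv1D` above:

```python
def conv1d_out_len(length, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of a 1-D convolution.

    `padding` may be an int (symmetric) or a (pad_left, pad_right) tuple.
    """
    if isinstance(padding, int):
        pad_total = 2 * padding          # symmetric: both sides padded equally
    else:
        pad_total = padding[0] + padding[1]
    # Dilation spreads the kernel taps apart, enlarging its receptive field.
    effective_kernel = dilation * (kernel_size - 1) + 1
    return (length + pad_total - effective_kernel) // stride + 1

assert conv1d_out_len(10, 3) == 8              # no padding shrinks the output
assert conv1d_out_len(10, 3, padding=1) == 10  # 'same' length at stride 1
```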
## `Conv2d` {#max.nn.legacy.conv.Conv2d}
> class max.nn.legacy.conv.Conv2d(kernel\_size, in\_channels, out\_channels, dtype, stride=1, padding=0, dilation=1, num\_groups=1, device=None, has\_bias=False, permute=False, name=None)
A 2D convolution over an input signal composed of several input
planes.
**Example:**
```python
from max.dtype import DType
from max.graph import DeviceRef
import max.nn.legacy as nn

conv = nn.Conv2d(
kernel_size=3,
in_channels=64,
out_channels=128,
dtype=DType.float32,
stride=1,
padding=0,
has_bias=False,
name="conv2d_weight",
device=DeviceRef.GPU(),
)
```
### `bias` {#max.nn.legacy.conv.Conv2d.bias}
> bias: [Weight](../../graph/Weight.md#max.graph.Weight) | [None](https://docs.python.org/3/library/constants.html#None) = None
The optional bias vector stored on CPU with shape (out\_channels,).
Model init moves the bias to [`device`](#max.nn.legacy.conv.Conv2d.device) if present.
### `device` {#max.nn.legacy.conv.Conv2d.device}
> device: [DeviceRef](../../graph/ops.md#max.graph.ops.DeviceRef) | [None](https://docs.python.org/3/library/constants.html#None)
The device where matrix operations are performed.
### `dilation` {#max.nn.legacy.conv.Conv2d.dilation}
> dilation: [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]
Controls the dilation rate.
### `filter` {#max.nn.legacy.conv.Conv2d.filter}
> filter: [Weight](../../graph/Weight.md#max.graph.Weight)
The weight matrix stored on CPU with shape (height, width, in\_channels / num\_groups, out\_channels).
Model init moves the weight to [`device`](#max.nn.legacy.conv.Conv2d.device).
### `num_groups` {#max.nn.legacy.conv.Conv2d.num_groups}
> num\_groups: [int](https://docs.python.org/3/library/functions.html#int)
Number of blocked connections from input channels to output channels.
### `padding` {#max.nn.legacy.conv.Conv2d.padding}
> padding: [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]
Controls the amount of padding applied before and after the input for height and width dimensions.
Format: (pad\_top, pad\_bottom, pad\_left, pad\_right).
### `permute` {#max.nn.legacy.conv.Conv2d.permute}
> permute: [bool](https://docs.python.org/3/library/functions.html#bool) = False
Whether `self.filter` is permuted from the PyTorch order to the MAX order.
PyTorch order: (out\_channels, in\_channels / num\_groups, height, width).
MAX API order: (height, width, in\_channels / num\_groups, out\_channels).
### `shard()` {#max.nn.legacy.conv.Conv2d.shard}
> shard(devices)
Creates sharded views of this Conv2d layer across multiple devices.
**Parameters:**
devices ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[DeviceRef](../../graph/type.md#max.graph.type.DeviceRef)]) – Iterable of devices to place the shards on.
**Returns:**
List of sharded Conv2d instances, one for each device.
## `Conv3D` {#max.nn.legacy.conv.Conv3D}
### `bias` {#max.nn.legacy.conv.Conv3D.bias}
> bias: [Weight](../../graph/Weight.md#max.graph.Weight) | [None](https://docs.python.org/3/library/constants.html#None) = None
The optional bias vector stored on CPU with shape (out\_channels,).
Model init moves the bias to [`device`](#max.nn.legacy.conv.Conv3D.device) if present.
### `device` {#max.nn.legacy.conv.Conv3D.device}
> device: [DeviceRef](../../graph/ops.md#max.graph.ops.DeviceRef) | [None](https://docs.python.org/3/library/constants.html#None)
The device where matrix operations are performed.
### `dilation` {#max.nn.legacy.conv.Conv3D.dilation}
> dilation: [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]
Controls the dilation rate for depth, height, and width dimensions.
### `filter` {#max.nn.legacy.conv.Conv3D.filter}
> filter: [Weight](../../graph/Weight.md#max.graph.Weight)
The weight matrix stored on CPU with shape (depth, height, width, in\_channels / num\_groups, out\_channels).
Model init moves the weight to [`device`](#max.nn.legacy.conv.Conv3D.device).
### `num_groups` {#max.nn.legacy.conv.Conv3D.num_groups}
> num\_groups: [int](https://docs.python.org/3/library/functions.html#int)
Number of blocked connections from input channels to output channels.
### `padding` {#max.nn.legacy.conv.Conv3D.padding}
> padding: [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]
Controls the amount of padding applied before and after the input for depth, height, and width dimensions.
Format: (pad\_front, pad\_back, pad\_top, pad\_bottom, pad\_left, pad\_right).
### `permute` {#max.nn.legacy.conv.Conv3D.permute}
> permute: [bool](https://docs.python.org/3/library/functions.html#bool) = False
Whether `self.filter` is permuted from the PyTorch order to the MAX order.
PyTorch order: (out\_channels, in\_channels / num\_groups, depth, height, width).
MAX API order: (depth, height, width, in\_channels / num\_groups, out\_channels).
### `stride` {#max.nn.legacy.conv.Conv3D.stride}
> stride: [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]
Controls the stride for the cross-correlation.
---
## conv_transpose
## `ConvTranspose1d` {#max.nn.legacy.conv_transpose.ConvTranspose1d}
> class max.nn.legacy.conv\_transpose.ConvTranspose1d(length, in\_channels, out\_channels, dtype, stride=1, padding=0, dilation=1, output\_padding=0, device=None, has\_bias=False, permute=False, name=None)
A 1D transposed convolution operator over an input image composed of several input planes.
```python
conv = nn.ConvTranspose1d(
    length=kernel_size,
    in_channels=in_channels,
    out_channels=out_channels,
    dtype=dtype,
    stride=stride,
    padding=padding,
    output_padding=output_padding,
    has_bias=False,
    name="conv_transpose1d_weight",
    device=DeviceRef.GPU(),
)
```
### `bias` {#max.nn.legacy.conv_transpose.ConvTranspose1d.bias}
> bias: [Weight](../../graph/Weight.md#max.graph.Weight) | [None](https://docs.python.org/3/library/constants.html#None) = None
The optional bias vector stored on CPU with shape (out\_channels,).
Model init moves the bias to [`device`](#max.nn.legacy.conv_transpose.ConvTranspose1d.device) if present.
### `device` {#max.nn.legacy.conv_transpose.ConvTranspose1d.device}
> device: [DeviceRef](../../graph/type.md#max.graph.type.DeviceRef) | [None](https://docs.python.org/3/library/constants.html#None)
The device where matrix operations are performed.
### `dilation` {#max.nn.legacy.conv_transpose.ConvTranspose1d.dilation}
> dilation: [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]
Not implemented yet; dilation is assumed to be 1.
### `output_padding` {#max.nn.legacy.conv_transpose.ConvTranspose1d.output_padding}
> output\_padding: [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]
Additional size added to one side of the output shape. Default: 0.
### `padding` {#max.nn.legacy.conv_transpose.ConvTranspose1d.padding}
> padding: [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]
Controls the amount of padding applied before and after the input.
### `permute` {#max.nn.legacy.conv_transpose.ConvTranspose1d.permute}
> permute: [bool](https://docs.python.org/3/library/functions.html#bool)
Controls whether `self.weight` is permuted from PyTorch order to MAX order.
PyTorch order is: (in\_channels, out\_channels, kernel\_length)
Max API order: (kernel\_length, out\_channels, in\_channels).
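As with the other convolution layers, this permutation is a simple transpose (example shapes only):

```python
import numpy as np

w_torch = np.zeros((4, 8, 5))             # (in_channels, out_channels, kernel_length)
w_max = np.transpose(w_torch, (2, 1, 0))  # -> (kernel_length, out_channels, in_channels)
print(w_max.shape)  # (5, 8, 4)
```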
### `stride` {#max.nn.legacy.conv_transpose.ConvTranspose1d.stride}
> stride: [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]
Controls the stride for the cross-correlation.
### `weight` {#max.nn.legacy.conv_transpose.ConvTranspose1d.weight}
> weight: [Weight](../../graph/Weight.md#max.graph.Weight)
The weight matrix stored on CPU with shape (kernel\_length, out\_channels, in\_channels).
Model init moves the weight to [`device`](#max.nn.legacy.conv_transpose.ConvTranspose1d.device).
## `WeightNormConvTranspose1d` {#max.nn.legacy.conv_transpose.WeightNormConvTranspose1d}
> class max.nn.legacy.conv\_transpose.WeightNormConvTranspose1d(length, in\_channels, out\_channels, dtype, stride=1, padding=0, dilation=1, output\_padding=0, device=None, has\_bias=False, permute=False, name=None)
A 1D transposed convolution operator over an input image composed of several input planes.
This version uses weight normalization (Salimans & Kingma, 2016).
Weight normalization reparameterizes weights in terms of a direction vector `v` and a magnitude scalar `g`.
This can help improve optimization by decoupling the length and direction of weight vectors.
For example:
```python
conv = WeightNormConvTranspose1d(
    length=kernel_size,
    in_channels=in_channels,
    out_channels=out_channels,
    dtype=dtype,
    stride=stride,
    padding=padding,
    output_padding=output_padding,
    has_bias=False,
    device=DeviceRef.GPU(),
)
```
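The reparameterization itself can be sketched in NumPy. This is a minimal illustration; the layer may apply the norm over different reduction axes than the full-tensor norm used here:

```python
import numpy as np

v = np.arange(1.0, 7.0).reshape(2, 3)  # direction parameter (weight_v)
g = 2.0                                # magnitude parameter (weight_g)

# Effective weight: v rescaled to unit norm, then scaled by g,
# so the norm of w equals g regardless of the length of v.
w = g * v / np.linalg.norm(v)
print(np.isclose(np.linalg.norm(w), g))  # True
```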
### `conv` {#max.nn.legacy.conv_transpose.WeightNormConvTranspose1d.conv}
> conv: [ConvTranspose1d](#max.nn.legacy.conv_transpose.ConvTranspose1d)
The underlying ConvTranspose1d layer.
### `device` {#max.nn.legacy.conv_transpose.WeightNormConvTranspose1d.device}
> device: [DeviceRef](../../graph/type.md#max.graph.type.DeviceRef) | [None](https://docs.python.org/3/library/constants.html#None)
The device where matrix operations are performed.
### `weight_g` {#max.nn.legacy.conv_transpose.WeightNormConvTranspose1d.weight_g}
> weight_g: [Weight](../../graph/Weight.md#max.graph.Weight)
The magnitude parameter g for weight normalization.
### `weight_v` {#max.nn.legacy.conv_transpose.WeightNormConvTranspose1d.weight_v}
> weight_v: [Weight](../../graph/Weight.md#max.graph.Weight)
The direction parameter v for weight normalization.
---
## embedding (Legacy)
The `embedding` module provides classes for mapping integer indices (like
token IDs) to dense vector representations. These embedding operations are
fundamental building blocks for natural language processing, recommendation
systems, and other tasks involving discrete tokens.
* `Embedding`: Basic embedding lookup table for simple use cases
* `EmbeddingV2`: Enhanced embedding with device placement control and improved memory management
* `VocabParallelEmbedding`: Distributed embedding that shards the vocabulary across multiple devices for large embedding tables
Here’s an example demonstrating how to use embeddings:
```python
import max.nn as nn
from max.graph import Graph, ops, DeviceRef
from max.dtype import DType
import numpy as np
with Graph(name="embedding_example") as graph:
# Define dimensions
batch_size = 4
seq_length = 16
vocab_size = 10000
hidden_dim = 256
# Create input tensor of token indices
input_data = np.random.randint(0, vocab_size, (batch_size, seq_length), dtype=np.int32)
input_indices = ops.constant(input_data, dtype=DType.int32, device=DeviceRef.CPU())
# Create embedding layer
embedding = nn.EmbeddingV2(
vocab_size=vocab_size,
hidden_dim=hidden_dim,
dtype=DType.float32,
device=DeviceRef.GPU(),
name="token_embeddings"
)
# Look up embeddings for input indices
embeddings = embedding(input_indices)
print(f"Embedding output shape: {embeddings.shape}")
# Embedding output shape: [Dim(4), Dim(16), Dim(256)]
```
## `Embedding` {#max.nn.legacy.embedding.Embedding}
> class max.nn.legacy.embedding.Embedding(vocab\_size, hidden\_dim, dtype, device, quantization\_encoding=None, name=None)
A lookup table for embedding integer indices into dense vectors.
This layer maps each integer index to a dense vector of fixed size.
Embedding weights are stored on the CPU but are moved to the specified
device during the model init phase.
Example:
```python
embedding_layer = Embedding(
vocab_size=1000,
hidden_dim=256,
dtype=DType.float32,
device=DeviceRef.GPU(),
name="embeddings",
)
token_indices: TensorValueLike
embeddings = embedding_layer(token_indices)
```
### `device` {#max.nn.legacy.embedding.Embedding.device}
> device: [DeviceRef](../../graph/ops.md#max.graph.ops.DeviceRef)
The device on which embedding lookup is performed.
### `weight` {#max.nn.legacy.embedding.Embedding.weight}
> weight: [Weight](../../graph/Weight.md#max.graph.Weight)
The embedding weight matrix stored on the CPU.
Model init moves weights to the device specified in [`device`](#max.nn.legacy.embedding.Embedding.device).
## `VocabParallelEmbedding` {#max.nn.legacy.embedding.VocabParallelEmbedding}
> class max.nn.legacy.embedding.VocabParallelEmbedding(vocab\_size, hidden\_dim, dtype, devices, quantization\_encoding=None, name=None)
A lookup table for embedding integer indices into dense vectors.
This layer works like nn.Embedding except the embedding table is sharded
on the vocabulary dimension across all devices.
Example:
```python
embedding_layer = VocabParallelEmbedding(
vocab_size=1000,
hidden_dim=256,
dtype=DType.float32,
    devices=[DeviceRef.GPU(0), DeviceRef.GPU(1)],
name="embeddings",
)
# Token indices of shape: [batch, ..., num_indices].
token_indices: TensorValueLike
embeddings = embedding_layer(token_indices)
```
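One common way to shard a vocabulary across devices is contiguous blocks, where a global token id maps to a device and a row within that device's shard. This is an illustrative sketch only; the actual shard layout used by `VocabParallelEmbedding` may differ:

```python
vocab_size, num_devices = 1000, 2
shard_size = vocab_size // num_devices  # assumes vocab_size divides evenly

def locate(token_id: int) -> tuple[int, int]:
    """Map a global token id to (device index, row within that device's shard)."""
    return token_id // shard_size, token_id % shard_size

print(locate(42))   # (0, 42)  -> first shard
print(locate(734))  # (1, 234) -> second shard
```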
---
## float8_config
Float8 configuration data structures for models.
## `Float8Config` {#max.nn.legacy.float8_config.Float8Config}
> class max.nn.legacy.float8\_config.Float8Config(input\_scale, weight\_scale, mlp\_in\_float8, attn\_qkv\_in\_float8, embedding\_output\_dtype=None, bias\_dtype=None, quant\_method=None, quant\_algo=None)
Configures float8 quantization settings for a layer or model section.
### `attn_qkv_in_float8` {#max.nn.legacy.float8_config.Float8Config.attn_qkv_in_float8}
> attn\_qkv\_in\_float8: [set](https://docs.python.org/3/library/stdtypes.html#set)\[[int](https://docs.python.org/3/library/functions.html#int)]
Set of layer indices with attention QKV projections in float8.
QKV projections are either all quantized or all unquantized per layer:
either all of {q,k,v,o}\_proj are float8, or all are bfloat16.
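The per-layer all-or-nothing semantics amount to a set-membership check (a sketch with hypothetical layer indices):

```python
# Hypothetical layer indices quantized to float8; all other layers stay bfloat16.
attn_qkv_in_float8 = {0, 1, 2}

def qkv_dtype(layer_idx: int) -> str:
    # Per-layer all-or-nothing: every {q,k,v,o}_proj in the layer shares one dtype.
    return "float8" if layer_idx in attn_qkv_in_float8 else "bfloat16"

print(qkv_dtype(1))   # float8
print(qkv_dtype(10))  # bfloat16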
### `bias_dtype` {#max.nn.legacy.float8_config.Float8Config.bias_dtype}
> bias\_dtype: [DType](../../dtype.md#max.dtype.DType) | [None](https://docs.python.org/3/library/constants.html#None) = None
The `DType` of bias weights.
### `embedding_output_dtype` {#max.nn.legacy.float8_config.Float8Config.embedding_output_dtype}
> embedding\_output\_dtype: [DType](../../dtype.md#max.dtype.DType) | [None](https://docs.python.org/3/library/constants.html#None) = None
The `DType` of the output from the embedding layer.
### `input_scale` {#max.nn.legacy.float8_config.Float8Config.input_scale}
> input\_scale: [Float8InputScaleSpec](#max.nn.legacy.float8_config.Float8InputScaleSpec)
[`Float8InputScaleSpec`](#max.nn.legacy.float8_config.Float8InputScaleSpec) for input activation scaling.
### `is_dynamic` {#max.nn.legacy.float8_config.Float8Config.is_dynamic}
> property is\_dynamic: [bool](https://docs.python.org/3/library/functions.html#bool)
Returns `True` if the input scale is dynamic.
### `is_nvfp4` {#max.nn.legacy.float8_config.Float8Config.is_nvfp4}
> property is\_nvfp4: [bool](https://docs.python.org/3/library/functions.html#bool)
Returns `True` if this config represents modelopt NVFP4.
### `is_static` {#max.nn.legacy.float8_config.Float8Config.is_static}
> property is\_static: [bool](https://docs.python.org/3/library/functions.html#bool)
Returns `True` if the input scale is static.
### `mlp_in_float8` {#max.nn.legacy.float8_config.Float8Config.mlp_in_float8}
> mlp\_in\_float8: [set](https://docs.python.org/3/library/stdtypes.html#set)\[[int](https://docs.python.org/3/library/functions.html#int)]
Set of layer indices with MLPs in float8.
MLPs are either all quantized or all unquantized per layer:
either all of gate\_proj, down\_proj, and up\_proj are float8, or all are bfloat16.
### `quant_algo` {#max.nn.legacy.float8_config.Float8Config.quant_algo}
> quant\_algo: [str](https://docs.python.org/3/library/stdtypes.html#str) | [None](https://docs.python.org/3/library/constants.html#None) = None
Additional differentiator within the same `quant_method`, e.g. modelopt NVFP4 vs. FP8.
### `quant_method` {#max.nn.legacy.float8_config.Float8Config.quant_method}
> quant\_method: [str](https://docs.python.org/3/library/stdtypes.html#str) | [None](https://docs.python.org/3/library/constants.html#None) = None
The quantization method used (e.g., “fbgemm\_fp8”).
### `quantized_scales_type()` {#max.nn.legacy.float8_config.Float8Config.quantized_scales_type}
> quantized\_scales\_type(quantized\_shape, device\_ref)
Returns the TensorType of the scales tensor after dynamic quantization.
### `scales_granularity_mnk` {#max.nn.legacy.float8_config.Float8Config.scales_granularity_mnk}
> property scales\_granularity\_mnk: [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]
Returns the weight and input scale granularities on M, N and K axis.
### `weight_scale` {#max.nn.legacy.float8_config.Float8Config.weight_scale}
> weight\_scale: [Float8WeightScaleSpec](#max.nn.legacy.float8_config.Float8WeightScaleSpec)
[`Float8WeightScaleSpec`](#max.nn.legacy.float8_config.Float8WeightScaleSpec) for weight scaling.
## `Float8InputScaleSpec` {#max.nn.legacy.float8_config.Float8InputScaleSpec}
> class max.nn.legacy.float8\_config.Float8InputScaleSpec(granularity, origin, dtype, activation\_scale\_ub=None, block\_size=None)
Specifies how input activations are scaled for float8 quantization.
---
## hooks
## `PrintHook` {#max.nn.legacy.hooks.PrintHook}
> class max.nn.legacy.hooks.PrintHook(export\_path=None, filter=None)
Hook that prints/saves layer tensor inputs and outputs.
This class must be initialized before the graph is built so that the
print ops can be added to the graph.
### `export_path` {#max.nn.legacy.hooks.PrintHook.export_path}
> property export\_path: [str](https://docs.python.org/3/library/stdtypes.html#str) | [None](https://docs.python.org/3/library/constants.html#None)
### `name_layers()` {#max.nn.legacy.hooks.PrintHook.name_layers}
> name\_layers(model)
Create names for all layers in the model based on nested attributes.
**Parameters:**
model ([Layer](layer.md#max.nn.legacy.layer.Layer))
**Return type:**
None
### `print_value()` {#max.nn.legacy.hooks.PrintHook.print_value}
> print\_value(name, value)
Prints a value, and returns whether the print is successful.
**Parameters:**
* name ([str](https://docs.python.org/3/library/stdtypes.html#str))
* value ([Any](https://docs.python.org/3/library/typing.html#typing.Any))
---
## legacy
Legacy graph-based neural network API.
:::note Note
This is the legacy API for backward compatibility. For all new models, use
the eager tensor API from [nn](/max/api/python/nn).
:::
The legacy API provides graph-based layer implementations for building neural
networks. This API was the primary interface prior to MAX 26.1 and remains
available for backward compatibility.
**Using the Legacy API:**
```python
from max.nn.legacy import Module, Linear, LayerNorm
from max.nn.legacy.attention import AttentionWithRope
```
## Modules
* [`attention`](/max/api/python/nn/legacy/attention): Attention mechanisms for sequence modeling.
* [`clamp`](/max/api/python/nn/legacy/clamp): Value clamping utilities for tensor operations.
* [`comm`](/max/api/python/nn/legacy/comm): Communication primitives for distributed training.
* [`conv`](/max/api/python/nn/legacy/conv): Convolutional layers for spatial processing.
* [`conv_transpose`](/max/api/python/nn/legacy/conv_transpose): Transposed convolution for upsampling.
* [`embedding`](/max/api/python/nn/legacy/embedding): Embedding layers with vocabulary support.
* [`float8_config`](/max/api/python/nn/legacy/float8_config): Configuration for FP8 quantization.
* [`hooks`](/max/api/python/nn/legacy/hooks): Extension hooks for layer customization.
* [`kernels`](/max/api/python/nn/legacy/kernels): Custom kernel implementations.
* [`kv_cache`](/max/api/python/nn/legacy/kv_cache): Key-value cache for efficient generation.
* [`layer`](/max/api/python/nn/legacy/layer): Base classes for building graph-based layers.
* [`linear`](/max/api/python/nn/legacy/linear): Linear transformation layers with optional parallelism.
* [`lora`](/max/api/python/nn/legacy/lora): Low-Rank Adaptation for efficient fine-tuning.
* [`moe`](/max/api/python/nn/legacy/moe): Mixture of Experts layer implementations.
* [`norm`](/max/api/python/nn/legacy/norm): Normalization layers for training stability.
* [`rotary_embedding`](/max/api/python/nn/legacy/rotary_embedding): Rotary position embeddings for sequences.
* [`sampling`](/max/api/python/nn/legacy/sampling): Sampling strategies for generation.
* [`sequential`](/max/api/python/nn/legacy/sequential): Container for sequential layer composition.
* [`transformer`](/max/api/python/nn/legacy/transformer): Transformer building blocks and layers.
---
## kernels
Helper functions for wrapping custom kv cache/attention related ops.
## `apply_penalties_to_logits()` {#max.nn.legacy.kernels.apply_penalties_to_logits}
> max.nn.legacy.kernels.apply\_penalties\_to\_logits(logits\_buffer, frequency\_data, frequency\_offsets, \*, frequency\_penalty=0.0, presence\_penalty=0.0, repetition\_penalty=1.0)
Applies penalties to the logits.
**Parameters:**
* logits\_buffer ([BufferValue](../../graph/BufferValue.md#max.graph.BufferValue)) – The buffer to apply penalties to.
* frequency\_data ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – 2d tensor of shape \[unique\_tokens, 2], where
the first column indicates the token id and the second column
indicates the frequency of the token.
* frequency\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – 1d tensor of shape \[batch\_size + 1], indicating
start of each sequence’s data.
* frequency\_penalty (Value\[TensorType] | [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../../graph/shape.md#max.graph.shape.Shape) | [Dim](../../graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../../driver.md#max.driver.DLPackArray)) – The frequency penalty to apply to the model’s output.
A positive value will penalize new tokens based on their frequency
in the generated text: tokens will receive a penalty proportional
to the count of appearances.
* presence\_penalty (Value\[TensorType] | [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../../graph/shape.md#max.graph.shape.Shape) | [Dim](../../graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../../driver.md#max.driver.DLPackArray)) – The presence penalty to apply to the model’s output
A positive value will penalize new tokens that have already appeared
in the generated text at least once by applying a constant penalty.
* repetition\_penalty (Value\[TensorType] | [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../../graph/shape.md#max.graph.shape.Shape) | [Dim](../../graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../../driver.md#max.driver.DLPackArray)) – The repetition penalty to apply to the model’s
output. Values > 1 will penalize new tokens that have already
appeared in prompt and generated text at least once by dividing the
logits by the repetition penalty.
**Return type:**
None
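The three penalties described above can be sketched in NumPy. This is an illustrative approximation of the semantics, not the kernel itself; the kernel's exact formula (e.g. sign handling for the repetition penalty) may differ:

```python
import numpy as np

def penalize(logits, counts, freq_pen=0.0, pres_pen=0.0, rep_pen=1.0):
    """Apply frequency, presence, and repetition penalties to a logit vector,
    given per-token appearance counts."""
    logits = logits.copy()
    seen = counts > 0
    logits -= freq_pen * counts  # proportional to appearance count
    logits -= pres_pen * seen    # flat penalty for any appearance
    logits[seen] /= rep_pen      # divide seen tokens' logits
    return logits

logits = np.array([2.0, 1.0, -1.0])
counts = np.array([3, 0, 1])
print(penalize(logits, counts, freq_pen=0.5, pres_pen=0.2, rep_pen=2.0))
```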
## `batched_dynamic_scaled_fp8_matmul()` {#max.nn.legacy.kernels.batched_dynamic_scaled_fp8_matmul}
> max.nn.legacy.kernels.batched\_dynamic\_scaled\_fp8\_matmul(a, b, a\_scales, b\_scales, input\_scale\_spec, weight\_scale\_spec, out\_type=bfloat16)
Perform a batched blockwise scaled matmul of two tensors with scaling factors.
**Parameters:**
* a ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The first tensor to multiply (3D tensor).
* b ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The second tensor to multiply, must be transposed (3D tensor).
* a\_scales ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The scaling factors for the first tensor (3D tensor).
* b\_scales ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The scaling factors for the second tensor (3D tensor).
* input\_scale\_spec ([Float8InputScaleSpec](float8_config.md#max.nn.legacy.float8_config.Float8InputScaleSpec))
* weight\_scale\_spec ([Float8WeightScaleSpec](float8_config.md#max.nn.legacy.float8_config.Float8WeightScaleSpec))
* out\_type ([DType](../../dtype.md#max.dtype.DType))
## `batched_quantize_dynamic_scaled_float8()` {#max.nn.legacy.kernels.batched_quantize_dynamic_scaled_float8}
> max.nn.legacy.kernels.batched\_quantize\_dynamic\_scaled\_float8(input, input\_scale\_spec, weight\_scale\_spec, scale\_ub=1200.0, group\_size\_or\_per\_token=-1, out\_type=float8\_e4m3fn, scales\_type=bfloat16)
Dynamically quantize the input tensor to fp8.
**Parameters:**
* input ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The input tensor to quantize. Shape: [batch\_size, seq\_len, hidden\_size]
* scale\_ub ([float](https://docs.python.org/3/library/functions.html#float)) – The upper bound of the scale factor.
* group\_size\_or\_per\_token ([int](https://docs.python.org/3/library/functions.html#int)) – The group size for quantization. When set to -1,
the quantization is column-wise.
* out\_type ([DType](../../dtype.md#max.dtype.DType)) – The type of the output tensor.
* scales\_type ([DType](../../dtype.md#max.dtype.DType)) – The type of the scales tensor.
* input\_scale\_spec ([Float8InputScaleSpec](float8_config.md#max.nn.legacy.float8_config.Float8InputScaleSpec))
* weight\_scale\_spec ([Float8WeightScaleSpec](float8_config.md#max.nn.legacy.float8_config.Float8WeightScaleSpec))
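The per-token dynamic-scaling case can be sketched as follows. This is a simplified illustration (the `FP8_E4M3_MAX` constant and clipping are assumptions about the scheme; the kernel also supports group-wise and column-wise modes and emits a separate scales tensor):

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value

def quantize_per_token(x, scale_ub=1200.0):
    """Compute a per-token scale from the row-wise absmax (capped at scale_ub)
    and rescale values into the representable fp8 range."""
    amax = np.abs(x).max(axis=-1, keepdims=True)
    scale = np.minimum(amax, scale_ub) / FP8_E4M3_MAX
    q = np.clip(x / scale, -FP8_E4M3_MAX, FP8_E4M3_MAX)
    return q, scale

x = np.array([[100.0, -300.0, 50.0]])
q, scale = quantize_per_token(x)
print(q.round(1), scale)
```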
## `block_scales_interleave()` {#max.nn.legacy.kernels.block_scales_interleave}
> max.nn.legacy.kernels.block\_scales\_interleave(scales, sf\_vector\_size=16, scales\_type=float8\_e4m3fn)
Interleave the block scales tensor in \[M, N] layout to \[ceildiv(M, 128), ceildiv(N, sf\_vector\_size \* 4), 32, 4, 4] layout.
**Parameters:**
* scales ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The scales tensor to interleave in \[M, N] layout.
* sf\_vector\_size ([int](https://docs.python.org/3/library/functions.html#int)) – The block size for the scaling factors.
* scales\_type ([DType](../../dtype.md#max.dtype.DType))
**Returns:**
The interleaved scales tensor in \[ceildiv(M, 128), ceildiv(N, sf\_vector\_size \* 4), 32, 4, 4] layout.
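The output shape follows directly from the stated layout; a small helper makes the arithmetic concrete:

```python
from math import ceil

def interleaved_shape(m: int, n: int, sf_vector_size: int = 16) -> tuple:
    # [M, N] -> [ceildiv(M, 128), ceildiv(N, sf_vector_size * 4), 32, 4, 4]
    return (ceil(m / 128), ceil(n / (sf_vector_size * 4)), 32, 4, 4)

print(interleaved_shape(256, 64))   # (2, 1, 32, 4, 4)
print(interleaved_shape(130, 100))  # (2, 2, 32, 4, 4)
```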
## `convert_weights_to_fp8_fnuz_if_needed()` {#max.nn.legacy.kernels.convert_weights_to_fp8_fnuz_if_needed}
> max.nn.legacy.kernels.convert\_weights\_to\_fp8\_fnuz\_if\_needed(weight, weight\_scale)
Convert weights and scales to FP8 FNUZ format if needed for AMD GPUs.
This utility function checks whether FP8 FNUZ conversion is needed (currently
only on AMD MI300 GPUs) and performs the conversion if required. This
centralizes conversion logic that was previously duplicated across multiple files.
**Parameters:**
* weight ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The weight tensor to potentially convert.
* weight\_scale ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The weight scale factor.
**Returns:**
Tuple of (weight, weight\_scale) - converted if needed, original otherwise.
## `cross_attention_ragged()` {#max.nn.legacy.kernels.cross_attention_ragged}
> max.nn.legacy.kernels.cross\_attention\_ragged(kv\_params, input, input\_row\_offsets, kv\_collection, layer\_idx, mask\_variant, kv\_input\_row\_offsets, q\_max\_seq\_len, scale, local\_window\_size=-1)
Computes cross attention provided the !mo.opaque KV Cache.
Notably, this materializes the attention mask (dependent on MHAMaskVariant)
within the kernel.
input and input\_row\_offsets are used together to implement the ragged
tensor: input\_row\_offsets indicates where each batch starts and ends in
input, while kv\_input\_row\_offsets represents the KV sequence length.
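The ragged input/input\_row\_offsets convention used throughout these kernels can be sketched as follows (hidden\_dim=8 is an arbitrary example value):

```python
import numpy as np

# Three sequences of lengths 3, 4, and 2 packed into one ragged buffer.
input = np.zeros((9, 8))                    # [total_seq_len, hidden_dim]
input_row_offsets = np.array([0, 3, 7, 9])  # [batch_size + 1]

# Each sequence b occupies rows input_row_offsets[b] : input_row_offsets[b + 1].
lengths = [
    int(input_row_offsets[b + 1] - input_row_offsets[b])
    for b in range(len(input_row_offsets) - 1)
]
print(lengths)  # [3, 4, 2]
```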
## `dynamic_block_scaled_matmul_fp4()` {#max.nn.legacy.kernels.dynamic_block_scaled_matmul_fp4}
> max.nn.legacy.kernels.dynamic\_block\_scaled\_matmul\_fp4(a, b, a\_scales, b\_scales, tensor\_sf, sf\_vector\_size=16, out\_type=bfloat16)
Perform a matmul of two FP4 tensors with 1D-block scaled scaling factors.
**Parameters:**
* a ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The first tensor to multiply.
* b ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The second tensor to multiply, must be transposed.
* a\_scales ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The scaling factors for the first tensor.
* b\_scales ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The scaling factors for the second tensor.
* tensor\_sf ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [float](https://docs.python.org/3/library/functions.html#float)) – Buffer-wise scaling factor equal to weight\_scale\_2 \* input\_scale (non-inverted).
* sf\_vector\_size ([int](https://docs.python.org/3/library/functions.html#int))
* out\_type ([DType](../../dtype.md#max.dtype.DType))
## `dynamic_scaled_matmul()` {#max.nn.legacy.kernels.dynamic_scaled_matmul}
> max.nn.legacy.kernels.dynamic\_scaled\_matmul(a, b, a\_scales, b\_scales, input\_scale\_spec, weight\_scale\_spec, out\_type=bfloat16)
Perform a matmul of two tensors with scaling factors. Currently only
supports channel-wise scaling for weights and per-token scaling for inputs.
**Parameters:**
* a ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The first tensor to multiply.
* b ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The second tensor to multiply, must be transposed.
* a\_scales ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The scaling factors for the first tensor.
* b\_scales ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The scaling factors for the second tensor.
* input\_scale\_spec ([Float8InputScaleSpec](float8_config.md#max.nn.legacy.float8_config.Float8InputScaleSpec))
* weight\_scale\_spec ([Float8WeightScaleSpec](float8_config.md#max.nn.legacy.float8_config.Float8WeightScaleSpec))
* out\_type ([DType](../../dtype.md#max.dtype.DType))
## `flare_mla_decode_ragged()` {#max.nn.legacy.kernels.flare_mla_decode_ragged}
> max.nn.legacy.kernels.flare\_mla\_decode\_ragged(kv\_params, input, input\_row\_offsets, kv\_collection, layer\_idx, mask\_variant, scale, qk\_rope\_dim=64)
Computes flash (self) attention provided the !mo.opaque KV Cache.
Notably, this materializes the attention mask (dependent on MHAMaskVariant)
within the kernel.
input and input\_row\_offsets are used together to implement the ragged
tensor: input\_row\_offsets indicates where each batch starts and ends in
input.
Note that this is self attention and the KV sequence length is
assumed to be equal to the Q sequence length.
For KV sequence length != Q sequence length, use cross\_attention\_ragged.
## `flare_mla_decompress_k_cache()` {#max.nn.legacy.kernels.flare_mla_decompress_k_cache}
> max.nn.legacy.kernels.flare\_mla\_decompress\_k\_cache(kv\_params, buffer\_row\_offsets\_1d, cache\_offsets\_1d, buffer\_length, weight, kv\_collection, layer\_idx, buffer\_size)
This kernel decompresses the key cache by up-projecting latent representations
into the KV space using a weight matrix.
The process involves:
1. Copying buffer\_length latent vectors from the key cache into a contiguous
buffer (k\_latent)
2. Computing k = k\_latent @ weight.T to obtain the decompressed keys
**Returns:**
A tensor of shape \[buffer\_size, weight.shape\[0]] containing the decompressed
keys. Note that only the first buffer\_length tokens are valid.
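Step 2 is an ordinary up-projection matmul, shown here with illustrative sizes (the real latent and KV dims come from the model config):

```python
import numpy as np

buffer_length, latent_dim, kv_dim = 6, 16, 32
k_latent = np.zeros((buffer_length, latent_dim))  # latent vectors copied from the key cache
weight = np.zeros((kv_dim, latent_dim))           # up-projection weight

k = k_latent @ weight.T  # decompressed keys: [buffer_length, weight.shape[0]]
print(k.shape)  # (6, 32)
```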
## `flare_mla_prefill_plan()` {#max.nn.legacy.kernels.flare_mla_prefill_plan}
> max.nn.legacy.kernels.flare\_mla\_prefill\_plan(kv\_params, input\_row\_offsets, kv\_collection, layer\_idx, buffer\_size, max\_chunks=16)
This kernel plans how to process a batch of sequences with
varying lengths using a fixed-size buffer.
Each sequence in the batch has some existing cached tokens and new input
tokens. The kernel divides the total tokens into chunks of buffer\_size.
For each chunk (iteration), it calculates:
1. Buffer offsets for each sequence in each chunk
2. Cache offsets for each sequence in each chunk
3. Total buffer lengths for each processing iteration
## `flare_mla_prefill_ragged()` {#max.nn.legacy.kernels.flare_mla_prefill_ragged}
> max.nn.legacy.kernels.flare\_mla\_prefill\_ragged(kv\_params, input, k, v, input\_row\_offsets, buffer\_row\_offsets, cache\_offsets, kv\_collection, layer\_idx, mask\_variant, scale, qk\_rope\_dim=64)
Performs MLA prefill. In the MLA prefill, we need to decompress
the KV tensors, as we store the latent representations in the KV cache.
We will decompress the KV tensors into a fixed size buffer to avoid
out-of-memory errors. In case the total cache length is greater than
the buffer size, we will process the attention calculation in chunks.
This MLA prefill kernel will return the output tensor for this iteration
and the softmax info tensor for this iteration. Such tensors will be used
by the next iteration of the MLA prefill kernel to continue the attention
calculation.
**Parameters:**
* kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) – KVCacheParams
* input ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Input tensor
* k ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Key tensor
* v ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Value tensor
* input\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Indicates where each batch starts and ends in input
* buffer\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Indicates where each batch starts and ends in the buffer
* cache\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Indicates where each batch starts and ends in the KV cache
* kv\_collection (PagedCacheValues) – KV collection
* layer\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Layer index tensor
* mask\_variant ([MHAMaskVariant](attention/mask_config.md#max.nn.legacy.attention.mask_config.MHAMaskVariant)) – Mask variant
* scale ([float](https://docs.python.org/3/library/functions.html#float)) – Scale
* qk\_rope\_dim ([int](https://docs.python.org/3/library/functions.html#int)) – QK rope dimension
## `flash_attention_ragged()` {#max.nn.legacy.kernels.flash_attention_ragged}
> max.nn.legacy.kernels.flash\_attention\_ragged(kv\_params, input, input\_row\_offsets, kv\_collection, layer\_idx, mask\_variant, scale, local\_window\_size=-1, sink\_weights=None)
Computes flash (self) attention given the `!mo.opaque` KV cache.
Notably, this materializes the attention mask (dependent on MHAMaskVariant)
within the kernel.
input and input\_row\_offsets are used together to implement the ragged
tensor.
input\_row\_offsets indicates where each batch starts and ends in input.
Note that this is self attention and the KV sequence length is
assumed to be equal to the Q sequence length.
For KV sequence length != Q sequence length, use cross\_attention\_ragged.
**Parameters:**
* kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) – KVCacheParams object containing key-value cache parameters.
* input ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – TensorValue representing the input tensor with shape \[total\_seq\_len, hidden\_dim].
* input\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – TensorValue indicating the start and end of each batch in the input tensor with shape \[batch\_size + 1].
* kv\_collection (PagedCacheValues) – PagedCacheValues object for managing key-value cache.
* layer\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – TensorValue representing the layer index, expected to have dtype uint32.
* mask\_variant ([MHAMaskVariant](attention/mask_config.md#max.nn.legacy.attention.mask_config.MHAMaskVariant)) – MHAMaskVariant specifying the type of attention mask to use.
* scale ([float](https://docs.python.org/3/library/functions.html#float)) – float value used to scale the attention scores.
* local\_window\_size ([int](https://docs.python.org/3/library/functions.html#int)) – int specifying the size of the local attention window, default is -1 for no local window.
* sink\_weights ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | None) – Optional tensor of shape \[num\_heads] containing learnable sink weights for each attention head.
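For intuition, the ragged causal self-attention semantics described above can be sketched in plain NumPy: one softmax attention per row-offset slice, with future positions masked. This is an illustrative reference, not the kernel; the function name is ours, and sink weights and local windows are omitted.

```python
import numpy as np

def ragged_causal_attention(q, k, v, row_offsets, scale):
    """Reference semantics for ragged causal self-attention.

    q, k, v: [total_seq_len, num_heads, head_dim]; row_offsets: [batch+1].
    Each sequence attends only within its own row-offset slice.
    """
    out = np.empty_like(q)
    for b in range(len(row_offsets) - 1):
        s, e = row_offsets[b], row_offsets[b + 1]
        qs, ks, vs = q[s:e], k[s:e], v[s:e]               # [L, H, D]
        scores = np.einsum("qhd,khd->hqk", qs, ks) * scale
        L = e - s
        future = np.triu(np.ones((L, L), dtype=bool), k=1)  # causal mask
        scores = np.where(future[None, :, :], -np.inf, scores)
        weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
        weights /= weights.sum(axis=-1, keepdims=True)
        out[s:e] = np.einsum("hqk,khd->qhd", weights, vs)
    return out
```

With the causal mask, the first token of each sequence can only attend to itself, so its output equals its own value vector.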
## `flash_attention_ragged_gpu()` {#max.nn.legacy.kernels.flash_attention_ragged_gpu}
> max.nn.legacy.kernels.flash\_attention\_ragged\_gpu(q, k, v, input\_row\_offsets, max\_seq\_len, mask\_variant, scale, local\_window\_size=-1)
Computes flash attention for ragged inputs using GPU-optimized kernel
without a KV cache.
**Parameters:**
* q ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Query tensor of shape \[total\_seq\_len, num\_heads, head\_dim] (ragged)
* k ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Key tensor of shape \[total\_seq\_len, num\_heads, head\_dim] (ragged)
* v ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Value tensor of shape \[total\_seq\_len, num\_heads, head\_dim] (ragged)
* input\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Buffer of shape \[batch\_size + 1] with dtype uint32.
Indicates where each sequence starts and ends in the ragged tensors.
The values should be a prefix sum (cumulative sum) of sequence lengths.
* mask\_variant ([MHAMaskVariant](attention/mask_config.md#max.nn.legacy.attention.mask_config.MHAMaskVariant)) – The mask variant to use for attention
* scale ([float](https://docs.python.org/3/library/functions.html#float)) – Scaling factor for attention scores
* local\_window\_size ([int](https://docs.python.org/3/library/functions.html#int)) – Local window size for sliding window attention
* max\_seq\_len ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue))
**Returns:**
Output tensor of shape \[total\_seq\_len, num\_heads, head\_dim]
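The prefix-sum convention for `input_row_offsets` can be illustrated with a short NumPy sketch (the `seq_lens` values here are hypothetical):

```python
import numpy as np

# Hypothetical per-sequence lengths for a batch of 3 ragged sequences.
seq_lens = np.array([3, 1, 4], dtype=np.uint32)

# input_row_offsets is an exclusive prefix sum: offsets[i] is where
# sequence i starts in the packed [total_seq_len, ...] tensor, and
# offsets[-1] equals the total token count.
input_row_offsets = np.concatenate(([0], np.cumsum(seq_lens))).astype(np.uint32)

print(input_row_offsets)  # [0 3 4 8]
```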
## `fused_qk_padded_rope()` {#max.nn.legacy.kernels.fused_qk_padded_rope}
> max.nn.legacy.kernels.fused\_qk\_padded\_rope(kv\_params, input, kv\_collection, freqs\_cis, layer\_idx, valid\_lengths, interleaved=True)
Computes fused query-key RoPE with padded inputs and paged KV cache.
This function applies Rotary Positional Embeddings (RoPE) to both Q and K tensors,
where K is stored in the paged KV cache. This is the padded equivalent of
fused\_qk\_ragged\_rope.
**Parameters:**
* kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) – KV cache parameters.
* input ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Query tensor of shape \[batch, seq\_len, n\_heads, head\_dim].
* kv\_collection (PagedCacheValues) – Paged KV cache collection.
* freqs\_cis ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Frequency tensor of shape (max\_seq\_len \* 2, head\_dim).
* layer\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Layer index for KV cache (must be uint32 on CPU).
* valid\_lengths ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Buffer of shape \[batch] containing the valid length for each
sequence (must be uint32). RoPE is only applied to positions within
these lengths.
* interleaved ([bool](https://docs.python.org/3/library/functions.html#bool)) – Whether to use interleaved RoPE pattern.
**Returns:**
Query tensor with RoPE applied, same shape as input.
:::note Note
Unlike fused\_qk\_ragged\_rope which requires ragged inputs, this function
works with padded batch inputs where sequences may have different actual
lengths but are padded to a uniform shape.
:::
## `fused_qk_ragged_rope()` {#max.nn.legacy.kernels.fused_qk_ragged_rope}
> max.nn.legacy.kernels.fused\_qk\_ragged\_rope(kv\_params, input, input\_row\_offsets, kv\_collection, freqs\_cis, layer\_idx, interleaved=True, position\_ids=None, mrope\_section=None)
Computes fused query-key attention with rotary positional encodings and ragged inputs.
**Parameters:**
* kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) – KV cache parameters
* input ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – \[batch\_size \* seq\_len, n\_heads, head\_dim]
* input\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Ragged tensor offsets indicating where each batch starts and ends
* kv\_collection (PagedCacheValues) – KV cache collection
* freqs\_cis ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – tensor of shape (max\_seq\_len \* 2, head\_dim)
* layer\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Layer index for KV cache
* interleaved ([bool](https://docs.python.org/3/library/functions.html#bool)) – Whether to use interleaved RoPE pattern
* position\_ids ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | None) – Optional ragged 2D array of position IDs. If None, defaults to
cache\_length + token\_idx for each token. When num\_sections > 1,
mrope\_section must be provided to indicate each section of the head\_dim
to apply RoPE to. Shape: [num\_sections, total\_seq\_len]
* mrope\_section ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)] | None) – Optional list of integers indicating the section of the head\_dim to
apply RoPE to. Must be used in conjunction with position\_ids.
input and input\_row\_offsets are used together to implement the ragged tensor.
input\_row\_offsets indicates where each batch starts and ends in input. If input
is not of the same dtype as freqs\_cis, it will be cast to the dtype of freqs\_cis
for the computation, and cast back to the original dtype after the computation is
finished.
When position\_ids and mrope\_section are provided, it replaces the default position
calculation (cache\_length + token\_idx) with explicit position values. This is useful for
3D RoPE in models like Qwen2.5-VL that need custom position encoding.
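As an illustration, the default position calculation (`cache_length + token_idx`) can be sketched as follows; the helper name and the NumPy loop are ours, not part of the API:

```python
import numpy as np

def default_position_ids(input_row_offsets, cache_lengths):
    """Sketch of the default per-token positions (cache_length + token_idx).

    input_row_offsets: [batch+1] prefix sums of new-token counts.
    cache_lengths: [batch] tokens already present in the KV cache.
    """
    positions = []
    for b in range(len(input_row_offsets) - 1):
        n_new = input_row_offsets[b + 1] - input_row_offsets[b]
        # New tokens continue from where the cached sequence left off.
        positions.append(cache_lengths[b] + np.arange(n_new))
    return np.concatenate(positions)
```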
## `fused_qkv_padded_matmul()` {#max.nn.legacy.kernels.fused_qkv_padded_matmul}
> max.nn.legacy.kernels.fused\_qkv\_padded\_matmul(kv\_params, input, wqkv, kv\_collection, layer\_idx, valid\_lengths, n\_heads)
Computes fused query, key, and value projections with padded input.
This is for non-ragged (padded batch) inputs where sequences may have
different actual lengths but are padded to a uniform shape.
**Parameters:**
* kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) – KV cache parameters.
* input ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Input tensor with shape \[batch\_size, seq\_len, hidden\_dim].
* wqkv ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Weight tensor for Q, K, V projections.
* kv\_collection (PagedCacheValues) – Paged KV cache collection.
* layer\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Layer index for cache lookup (must be uint32).
* valid\_lengths ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Buffer of shape \[batch] containing the valid length for each
sequence (must be uint32). K and V are only written to cache for
positions within these lengths.
* n\_heads ([int](https://docs.python.org/3/library/functions.html#int)) – Number of attention heads.
**Returns:**
Query projections tensor. K and V projections are written to cache.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – on input shapes/dtypes that are invalid for the kernel.
## `fused_qkv_ragged_matmul()` {#max.nn.legacy.kernels.fused_qkv_ragged_matmul}
> max.nn.legacy.kernels.fused\_qkv\_ragged\_matmul(kv\_params, input, input\_row\_offsets, wqkv, kv\_collection, layer\_idx, n\_heads, bias=None)
Computes fused query, key, and value projections with ragged input.
input and input\_row\_offsets are used together to implement the ragged
tensor.
input\_row\_offsets indicates where each batch starts and ends in input.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – on input shapes/dtypes that are invalid for the kernel.
## `fused_qkv_ragged_matmul_quantized()` {#max.nn.legacy.kernels.fused_qkv_ragged_matmul_quantized}
> max.nn.legacy.kernels.fused\_qkv\_ragged\_matmul\_quantized(kv\_params, input, input\_row\_offsets, wqkv, kv\_collection, layer\_idx, n\_heads, quantization\_config, perm\_idx=None, bias=None)
Computes fused query, key, and value projections with ragged input and
quantized weight matrices. A quantization\_config must be provided.
input and input\_row\_offsets are used together to implement the ragged
tensor.
input\_row\_offsets indicates where each batch starts and ends in input.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – on input shapes/dtypes that are invalid for the kernel.
## `grouped_dynamic_scaled_fp8_matmul()` {#max.nn.legacy.kernels.grouped_dynamic_scaled_fp8_matmul}
> max.nn.legacy.kernels.grouped\_dynamic\_scaled\_fp8\_matmul(hidden\_states, weight, a\_scales, b\_scales, expert\_start\_indices, expert\_ids, expert\_usage\_stats\_host, input\_scale\_spec, weight\_scale\_spec, out\_type=bfloat16, tokens\_padded\_per\_expert=False)
Performs a grouped blockwise scaled matmul of two tensors with scaling
factors, as used in MoE layers.
hidden\_states and expert\_start\_indices are used together to implement
the ragged tensor.
**Parameters:**
* hidden\_states ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The first tensor to multiply. (2D tensor)
* weight ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The second tensor to multiply, must be transposed. (3D tensor)
* a\_scales ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The scaling factors for the first tensor. (2D tensor)
* b\_scales ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The scaling factors for the second tensor. (3D tensor)
* expert\_start\_indices ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – indicates where each group starts and ends in hidden\_states.
* expert\_ids ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The id of the expert for each group in hidden\_states.
* expert\_usage\_stats\_host ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The maximum number of tokens assigned to any expert, and the number of active experts.
* input\_scale\_spec ([Float8InputScaleSpec](float8_config.md#max.nn.legacy.float8_config.Float8InputScaleSpec)) – The scaling granularity for the input tensor.
* weight\_scale\_spec ([Float8WeightScaleSpec](float8_config.md#max.nn.legacy.float8_config.Float8WeightScaleSpec)) – The scaling granularity for the weight tensor.
* tokens\_padded\_per\_expert ([bool](https://docs.python.org/3/library/functions.html#bool)) – If True, the number of tokens for each local expert is guaranteed to be
padded so that a\_scales is aligned to 16 bytes. This is needed by the optimized grouped matmul kernel.
* out\_type ([DType](../../dtype.md#max.dtype.DType))
## `grouped_dynamic_scaled_nvfp4_matmul()` {#max.nn.legacy.kernels.grouped_dynamic_scaled_nvfp4_matmul}
> max.nn.legacy.kernels.grouped\_dynamic\_scaled\_nvfp4\_matmul(hidden\_states, weight, a\_scales, b\_scales, expert\_start\_indices, a\_scale\_offsets, expert\_ids, expert\_scales, expert\_usage\_stats\_host, out\_type=bfloat16)
Performs grouped NVFP4 matmul for MoE layers.
Performs a grouped matmul with NVFP4 (4-bit) quantized inputs and weights.
The inputs are packed as uint8 (2 NVFP4 values per byte) with float8\_e4m3fn
scaling factors. NVFP4 uses fixed 1D block scaling with 16 elements per
scale factor along the K dimension.
`hidden_states` and `expert_start_indices` together implement the ragged
tensor representation for variable-length expert inputs.
**Parameters:**
* hidden\_states ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The input activations with shape `[total_tokens, K/2]`
where K is the unpacked hidden dimension. Dtype must be uint8
(packed NVFP4).
* weight ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The expert weights with shape `[num_experts, N, K/2]`.
Dtype must be uint8 (packed NVFP4).
* a\_scales ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Scaling factors for inputs with shape
`[num_scale_rows, K_groups, 32, 4, 4]`. Dtype must be float8\_e4m3fn.
* b\_scales ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Scaling factors for weights with shape
`[num_experts, N_groups, K_groups, 32, 4, 4]`. Dtype must be
float8\_e4m3fn.
* expert\_start\_indices ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Indices indicating where each expert’s tokens
start in `hidden_states`.
* a\_scale\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The offsets of the input scale tiles for each expert.
* expert\_ids ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The expert ID for each group.
* expert\_scales ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Per-expert scaling factors with shape `[num_experts]`.
Dtype must be float32. Multiplied with the matmul output in the
epilogue.
* expert\_usage\_stats\_host ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – A tensor containing \[max\_tokens\_per\_expert,
num\_active\_experts].
* out\_type ([DType](../../dtype.md#max.dtype.DType)) – Output dtype. Defaults to bfloat16.
* tokens\_padded\_per\_expert – If True, tokens per expert are padded for
alignment. Defaults to False.
**Returns:**
The matmul result with shape `[total_tokens, N]` and dtype `out_type`.
## `grouped_matmul_ragged()` {#max.nn.legacy.kernels.grouped_matmul_ragged}
> max.nn.legacy.kernels.grouped\_matmul\_ragged(hidden\_states, weight, expert\_start\_indices, expert\_ids, expert\_usage\_stats\_host)
Grouped matmul used in MoE layer.
hidden\_states and expert\_start\_indices are used together to implement
the ragged tensor. expert\_start\_indices indicates where each group starts
and ends in hidden\_states.
expert\_ids is the id of the expert for each group in hidden\_states.
expert\_usage\_stats\_host is the maximum number of tokens assigned to any
expert, and the number of active experts.
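A minimal NumPy sketch of these semantics (an illustrative reference, not the kernel; it assumes `weight` is laid out as `[num_experts, N, K]` and applied transposed, consistent with the grouped matmul variants above):

```python
import numpy as np

def grouped_matmul_ragged_ref(hidden_states, weight,
                              expert_start_indices, expert_ids):
    """NumPy sketch of grouped-matmul semantics.

    hidden_states: [total_tokens, K], packed so tokens for each expert
    group are contiguous. weight: [num_experts, N, K]. Each group g is
    multiplied by the transposed weight of its assigned expert.
    """
    out = np.empty((hidden_states.shape[0], weight.shape[1]),
                   dtype=hidden_states.dtype)
    for g in range(len(expert_start_indices) - 1):
        s, e = expert_start_indices[g], expert_start_indices[g + 1]
        out[s:e] = hidden_states[s:e] @ weight[expert_ids[g]].T
    return out
```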
## `kv_cache_copy_pages_d2h()` {#max.nn.legacy.kernels.kv_cache_copy_pages_d2h}
> max.nn.legacy.kernels.kv\_cache\_copy\_pages\_d2h(device\_kv\_collection, device\_page\_ids, host\_kv\_blocks, host\_page\_ids, layer\_idx, device\_ref)
Copy KV cache pages from GPU to CPU for a single layer.
Performs async GPU->CPU copy of specified pages for layer-wise KV cache
offloading.
**Parameters:**
* device\_kv\_collection (PagedCacheValues) – Source KV cache on GPU.
* device\_page\_ids ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Source page IDs to read from GPU.
* host\_kv\_collection – Destination KV cache on CPU.
* host\_page\_ids ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Destination page IDs to write to CPU.
Must have same length as device\_page\_ids.
* layer\_idx ([int](https://docs.python.org/3/library/functions.html#int)) – Which layer to copy.
* device\_ref ([DeviceRef](../../graph/type.md#max.graph.type.DeviceRef)) – Device for the GPU context.
* host\_kv\_blocks ([BufferValue](../../graph/BufferValue.md#max.graph.BufferValue))
**Return type:**
None
## `kv_cache_get_max_seq_len()` {#max.nn.legacy.kernels.kv_cache_get_max_seq_len}
> max.nn.legacy.kernels.kv\_cache\_get\_max\_seq\_len(kv\_params, kv\_collection)
Returns the maximum sequence length currently stored in the KV cache.
## `kv_cache_ragged_2m_iadd()` {#max.nn.legacy.kernels.kv_cache_ragged_2m_iadd}
> max.nn.legacy.kernels.kv\_cache\_ragged\_2m\_iadd(kv\_params, a, kv\_collection, input\_row\_offsets, lora\_end\_idx, batch\_seq\_len, layer\_idx)
In-place add to paged KV cache with interleaved K/V layout.
Performs an in-place addition of new key-value projections to paged KV cache.
The input tensor a uses a “2M” layout where keys and values are interleaved:
rows \[0, m) contain keys and rows \[m, 2m) contain values, where m is the number
of tokens.
**Parameters:**
* kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) – KV cache configuration parameters. Must have cache\_strategy
set to PAGED and page\_size must be defined.
* a ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Input tensor with interleaved K/V data, shape (2\*m, hidden\_size) where
m is the number of tokens. Rows \[0, m) are keys, rows \[m, 2m) are values.
* kv\_collection (PagedCacheValues) – The paged KV cache collection containing cache blocks,
cache lengths, lookup tables, and max lengths tensors.
* input\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Ragged tensor offsets indicating where each batch starts and ends
* lora\_end\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – End index of LoRA token portion. Marks the boundary between
LoRA sequences and base model sequences in the batch.
* batch\_seq\_len ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Total sequence length in the batch. Used for indexing
into the value portion of a.
* layer\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The transformer layer index to update in the KV cache.
**Raises:**
* [ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If a does not have rank 2.
* [ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If input\_row\_offsets does not have rank 1.
* [ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If kv\_params.cache\_strategy is not PAGED.
* [ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If kv\_params.page\_size is None.
**Return type:**
None
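The "2M" layout can be constructed with a simple concatenation; the following is an illustrative sketch with placeholder projections, not API code:

```python
import numpy as np

# Sketch of the "2M" interleaved layout expected by
# kv_cache_ragged_2m_iadd: for m tokens, rows [0, m) hold the key
# projections and rows [m, 2m) hold the value projections.
m, hidden_size = 4, 8
k_proj = np.ones((m, hidden_size))        # placeholder key projections
v_proj = np.full((m, hidden_size), 2.0)   # placeholder value projections

a = np.concatenate([k_proj, v_proj], axis=0)  # shape (2*m, hidden_size)
assert a.shape == (2 * m, hidden_size)
```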
## `kv_cache_ragged_radd()` {#max.nn.legacy.kernels.kv_cache_ragged_radd}
> max.nn.legacy.kernels.kv\_cache\_ragged\_radd(kv\_params, a, kv\_collection, input\_row\_offsets, batch\_offset, layer\_idx)
This function adds a tensor to a slice of the KVCache, sliced on the batch dimension.
This expects the requests being sliced out to be contiguous at the front of
the tensor, so the addition applies only to the last requests in the batch.
**Parameters:**
* a ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The tensor to add to the KVCache.
* kv\_collection (PagedCacheValues) – The KVCache collection to add to.
* input\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The offsets of the input tensor.
* batch\_offset ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The batch to start applying the r-add to.
* layer\_idx ([int](https://docs.python.org/3/library/functions.html#int)) – The layer index to add to.
* kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams))
**Return type:**
None
## `matmul_k_cache_ragged()` {#max.nn.legacy.kernels.matmul_k_cache_ragged}
> max.nn.legacy.kernels.matmul\_k\_cache\_ragged(kv\_params, hidden\_states, input\_row\_offsets, weight, kv\_collection, layer\_idx)
Computes key projections with ragged input.
hidden\_states and input\_row\_offsets are used together to
implement the ragged tensor.
input\_row\_offsets indicates where each batch starts and ends in input.
**Parameters:**
* kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) – KVCacheParams object containing key-value cache parameters.
* hidden\_states ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – TensorValue representing the input tensor with shape
\[M=total\_seq\_len, K=hidden\_dim].
* input\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – TensorValue indicating the start and end of each
batch in the input tensor with shape \[batch\_size + 1].
* weight ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – TensorValue representing the weight tensor with shape
\[N=num\_heads, K=hidden\_dim].
* input\_scale ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – TensorValue representing the input scale tensor with shape
\[ceildiv(K / BLOCK\_SIZE\_K), ceildiv(M / BLOCK\_SIZE\_M)].
* weight\_scale ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – TensorValue representing the weight scale tensor with
shape \[ceildiv(N / BLOCK\_SIZE\_N), ceildiv(K / BLOCK\_SIZE\_K)].
* kv\_collection (PagedCacheValues) – PagedCacheValues object for managing key-value cache.
* scales\_granularity\_mnk ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]) – tuple\[int, int, int] representing the
scaling (BLOCK\_SIZE\_M, BLOCK\_SIZE\_N, BLOCK\_SIZE\_K).
* layer\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – TensorValue representing the layer index, expected to have
dtype uint32.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – on input shapes/dtypes that are invalid for the kernel,
or when the cache strategy is not supported.
**Return type:**
None
## `matmul_kv_cache_ragged()` {#max.nn.legacy.kernels.matmul_kv_cache_ragged}
> max.nn.legacy.kernels.matmul\_kv\_cache\_ragged(kv\_params, hidden\_states, input\_row\_offsets, weight, kv\_collection, layer\_idx)
Computes key and value projections with ragged input.
hidden\_states and input\_row\_offsets are used together to
implement the ragged tensor.
input\_row\_offsets indicates where each batch starts and ends in input.
## `merge_ragged_tensors()` {#max.nn.legacy.kernels.merge_ragged_tensors}
> max.nn.legacy.kernels.merge\_ragged\_tensors(a, a\_row\_offsets, b, b\_row\_offsets)
Merges two ragged tensors into a single ragged tensor.
Both ragged tensors must have the same batch size (same number of row
offsets). This function interleaves the rows from each tensor based on
their row offsets.
**Parameters:**
* a ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The first ragged tensor of shape \[total\_a\_rows, …].
* a\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The row offsets of the first ragged tensor, indicating
where each batch starts and ends in a.
* b ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The second ragged tensor of shape \[total\_b\_rows, …].
* b\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The row offsets of the second ragged tensor, indicating
where each batch starts and ends in b.
**Returns:**
* The merged ragged tensor with shape
\[total\_a\_rows + total\_b\_rows, …].
* The merged row offsets with the same shape as input row offsets.
**Return type:**
A tuple of two tensors
Example:
```python
a = [1, 2, 3, 4, 5, 6]
a_row_offsets = [0, 2, 6]
b = [7, 8, 9, 10]
b_row_offsets = [0, 3, 4]
merged_tensor, merged_row_offsets = merge_ragged_tensors(
a, a_row_offsets, b, b_row_offsets)
merged_tensor = [1, 2, 7, 8, 9, 3, 4, 5, 6, 10]
merged_row_offsets = [0, 5, 10]
```
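The interleaving above can be expressed as a short pure-Python reference on flat lists (an illustrative sketch of the semantics, not the kernel):

```python
def merge_ragged_ref(a, a_offsets, b, b_offsets):
    """Pure-Python sketch of merge_ragged_tensors' semantics.

    For each batch i, the rows of a are followed by the rows of b,
    and the merged offsets track the cumulative row count.
    """
    merged, offsets = [], [0]
    for i in range(len(a_offsets) - 1):
        merged.extend(a[a_offsets[i]:a_offsets[i + 1]])
        merged.extend(b[b_offsets[i]:b_offsets[i + 1]])
        offsets.append(len(merged))
    return merged, offsets
```

Running this on the example above reproduces the documented result.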
## `mla_decode_branch_fp8()` {#max.nn.legacy.kernels.mla_decode_branch_fp8}
> max.nn.legacy.kernels.mla\_decode\_branch\_fp8(q, input\_row\_offsets, freqs\_cis, kv\_a\_proj\_layernorm, w\_uk, w\_uk\_scale, w\_uv, w\_uv\_scale, kv\_params, kv\_collection, layer\_idx, mask\_variant, scale, epsilon, v\_head\_dim, float8\_config)
This is a manually fused kernel that performs the following operations:
* Apply RoPE to the query and the key cache (in-place).
* Apply RMSNorm to the non-rope portion of the key cache (in-place).
* Project q\_nope to kv\_latent\_dim through an fp8 batched matmul:
q\_nope\_proj = q\_nope\_t @ w\_uk
* Concatenate q\_nope\_proj and q\_rope:
q\_full = concat(q\_nope\_proj, q\_rope, axis=2)
* Perform MLA decode
* Project raw\_output to v\_head\_dim through another fp8 batched matmul:
output = raw\_output\_t @ w\_uv
**Parameters:**
* q ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Combined query tensor containing both nope and rope parts. Shape:
\[tot\_seq\_len, num\_heads, qk\_nope\_head\_dim + qk\_rope\_head\_dim].
* input\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Indicates where each request starts and ends in
input. This is a 1D tensor of shape \[num\_batches + 1].
* freqs\_cis ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Precomputed RoPE frequency values for rotary position
embeddings. Shape: [max\_seq\_len, qk\_rope\_head\_dim].
* kv\_a\_proj\_layernorm ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – RMSNorm gamma weights for normalizing the KV cache.
Shape: [kv\_lora\_rank].
* w\_uk ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Weight matrix for projecting q\_nope to kv\_latent\_dim. Shape:
\[num\_heads, kv\_latent\_dim, qk\_nope\_head\_dim].
* w\_uk\_scale ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The scale for the weight matrix. Shape varies depending on
the float8\_config.
* w\_uv ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Weight matrix for projecting MLA decode output to v\_head\_dim.
Shape: [num\_heads, v\_head\_dim, kv\_latent\_dim].
* w\_uv\_scale ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The scale for the weight matrix. Shape varies depending on
the float8\_config.
* kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) – KVCacheParams
* kv\_collection (PagedCacheValues) – Paged KV Cache object.
* layer\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Layer index.
* mask\_variant ([MHAMaskVariant](attention/mask_config.md#max.nn.legacy.attention.mask_config.MHAMaskVariant)) – Mask variant.
* scale ([float](https://docs.python.org/3/library/functions.html#float)) – Scale for the attention calculation.
* epsilon ([float](https://docs.python.org/3/library/functions.html#float)) – Small constant for numerical stability in RMSNorm.
* v\_head\_dim ([int](https://docs.python.org/3/library/functions.html#int)) – Dimension of the V heads.
* float8\_config ([Float8Config](float8_config.md#max.nn.legacy.float8_config.Float8Config)) – Float8Config for the weight matrix.
## `mla_prefill_branch_fp8()` {#max.nn.legacy.kernels.mla_prefill_branch_fp8}
> max.nn.legacy.kernels.mla\_prefill\_branch\_fp8(q, input\_row\_offsets, freqs\_cis, kv\_a\_proj\_layernorm, buffer\_row\_offsets, cache\_offsets, buffer\_length, kv\_b\_proj, kv\_b\_proj\_scale, kv\_params, kv\_collection, layer\_idx, mask\_variant, scale, epsilon, v\_head\_dim, float8\_config)
This is a manually fused kernel that performs the following operations:
* Apply RoPE to the query and the key cache (in-place).
* Apply RMSNorm to the non-rope portion of the key cache (in-place).
* Copy the KV latent values from PagedKVCache to a contiguous buffer.
* Quantize the KV latent values to fp8.
* Up-project the latent KV values to full K and V through a matmul.
* Split the concatenated KV into K and V.
* Perform MLA prefill.
**Parameters:**
* q ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Combined query tensor containing both nope and rope parts. Shape:
\[tot\_seq\_len, num\_heads, qk\_nope\_head\_dim + qk\_rope\_head\_dim].
* input\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Indicates where each request starts and ends in
input. This is a 1D tensor of shape \[num\_batches + 1].
* freqs\_cis ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Precomputed RoPE frequency values for rotary position
embeddings. Shape: [max\_seq\_len, qk\_rope\_head\_dim].
* kv\_a\_proj\_layernorm ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – RMSNorm gamma weights for normalizing the KV cache.
Shape: [kv\_lora\_rank].
* buffer\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Indicates where each request’s KV latent values
should be stored in the contiguous buffer. This is a 1D tensor of
shape \[num\_batches + 1].
* cache\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Indicates the starting token position in the KV cache
from which to copy KV latent values for each request. This is a 1D
tensor of shape \[num\_batches + 1].
* buffer\_length ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The total number of tokens in the KV cache. Scalar.
* kv\_b\_proj ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Weight matrix for up-projecting the KV latent values to full
K and V. Shape: [num\_heads \* (qk\_nope\_head\_dim + v\_head\_dim),
kv\_latent\_dim].
* kv\_b\_proj\_scale ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The scale for the weight matrix. Shape varies
depending on the float8\_config.
* kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) – KVCacheParams
* kv\_collection (PagedCacheValues) – Paged KV Cache object.
* layer\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Layer index.
* mask\_variant ([MHAMaskVariant](attention/mask_config.md#max.nn.legacy.attention.mask_config.MHAMaskVariant)) – Mask variant.
* scale ([float](https://docs.python.org/3/library/functions.html#float)) – Scale for the attention calculation.
* epsilon ([float](https://docs.python.org/3/library/functions.html#float)) – Small constant for numerical stability in RMSNorm.
* v\_head\_dim ([int](https://docs.python.org/3/library/functions.html#int)) – Dimension of the V heads.
* float8\_config ([Float8Config](float8_config.md#max.nn.legacy.float8_config.Float8Config)) – Float8Config for the weight matrix.
## `mla_prefill_decode_graph_bf16()` {#max.nn.legacy.kernels.mla_prefill_decode_graph_bf16}
> max.nn.legacy.kernels.mla\_prefill\_decode\_graph\_bf16(q, input\_row\_offsets, freqs\_cis, kv\_norm\_gamma, buffer\_row\_offsets, cache\_offsets, buffer\_length, kv\_b\_proj, w\_uk, w\_uv, kv\_params, kv\_collection, layer\_idx, mask\_variant, scale, epsilon, v\_head\_dim)
BF16 mega-kernel for MLA prefill/decode.
Switches between prefill and decode based on the maximum sequence length in
the batch.
## `mla_prefill_decode_graph_fp8()` {#max.nn.legacy.kernels.mla_prefill_decode_graph_fp8}
> max.nn.legacy.kernels.mla\_prefill\_decode\_graph\_fp8(q, input\_row\_offsets, freqs\_cis, kv\_a\_proj\_layernorm, buffer\_row\_offsets, cache\_offsets, buffer\_length, kv\_b\_proj, kv\_b\_proj\_scale, w\_uk, w\_uk\_scale, w\_uv, w\_uv\_scale, kv\_params, kv\_collection, layer\_idx, mask\_variant, scale, epsilon, v\_head\_dim, float8\_config)
Fused MLA prefill/decode kernel for FP8.
Switches between prefill and decode based on the maximum sequence length in
the batch. See mla\_prefill\_branch\_fp8 and mla\_decode\_branch\_fp8 for the
dedicated paths.
**Returns:**
Output tensor of shape \[total\_seq\_len, num\_heads, v\_head\_dim].
**Raises:**
* [ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If input ranks/dtypes or cache strategy are invalid.
* [AssertionError](https://docs.python.org/3/library/exceptions.html#AssertionError) – If float8 scale block sizes are not set.
## `moe_create_indices()` {#max.nn.legacy.kernels.moe_create_indices}
> max.nn.legacy.kernels.moe\_create\_indices(topk\_ids, num\_local\_experts)
Creates indices for the MoE layer.
**Parameters:**
* topk\_ids ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The expert assignments for each token from the router.
* num\_local\_experts ([int](https://docs.python.org/3/library/functions.html#int)) – The number of experts on this device.
**Returns:**
* token\_expert\_order: The reordered token indices, grouped by assigned expert.
* expert\_start\_indices: The starting index for each expert’s token group in
the reordered sequence.
* restore\_token\_order: The indices to restore original token ordering after
expert computation.
* expert\_ids: The IDs of the active experts selected for the tokens.
* expert\_usage\_stats: The maximum number of tokens assigned to any expert,
and the number of active experts.
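For intuition, the grouping performed here can be sketched with NumPy. This is a simplified reference for the first three returned tensors only (it assumes a flattened 1D `topk_ids` and omits the expert-id and usage-stats outputs); the MAX kernel operates on `TensorValue`s:

```python
import numpy as np

def moe_create_indices_ref(topk_ids: np.ndarray, num_local_experts: int):
    """Sketch of moe_create_indices semantics (hypothetical helper)."""
    # Group token indices by assigned expert; a stable sort keeps the
    # original token order within each expert's group.
    token_expert_order = np.argsort(topk_ids, kind="stable")
    sorted_ids = topk_ids[token_expert_order]
    # Starting index of each expert's contiguous group in the reordering.
    counts = np.bincount(sorted_ids, minlength=num_local_experts)
    expert_start_indices = np.concatenate(([0], np.cumsum(counts)))
    # Inverse permutation restores the original token order afterwards.
    restore_token_order = np.argsort(token_expert_order, kind="stable")
    return token_expert_order, expert_start_indices, restore_token_order

order, starts, restore = moe_create_indices_ref(np.array([1, 0, 1, 0]), 2)
# order groups tokens of expert 0 first, then expert 1;
# applying restore to the reordered sequence recovers the input order.
```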
**Parameters:**
* expert\_scores ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The scores for each expert for each token. Shape:
\[num\_tokens, n\_routed\_experts].
* expert\_bias ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The bias for each expert. Shape: [n\_routed\_experts].
* n\_routed\_experts ([int](https://docs.python.org/3/library/functions.html#int)) – The total number of experts. Must be divisible by
n\_groups.
* n\_experts\_per\_tok ([int](https://docs.python.org/3/library/functions.html#int)) – The number of experts to be selected per token.
* n\_groups ([int](https://docs.python.org/3/library/functions.html#int)) – The total number of expert groups. Must evenly divide
n\_routed\_experts.
* topk\_group ([int](https://docs.python.org/3/library/functions.html#int)) – The maximum number of expert groups that a token will be
routed to.
* norm\_weights ([bool](https://docs.python.org/3/library/functions.html#bool)) – Whether to normalize the selected expert weights.
* routed\_scaling\_factor ([float](https://docs.python.org/3/library/functions.html#float)) – The scaling factor for the routed expert weights.
**Returns:**
* expert\_indices: The indices of the routed experts for each token.
Shape: [num\_tokens, n\_experts\_per\_tok].
* expert\_weights: The weights of the routed experts for each token.
Shape: [num\_tokens, n\_experts\_per\_tok].
**Return type:**
A tuple of two tensors
## `needs_fp8_fnuz_conversion()` {#max.nn.legacy.kernels.needs_fp8_fnuz_conversion}
> max.nn.legacy.kernels.needs\_fp8\_fnuz\_conversion()
Check if we need to convert FP8 E4M3FN to FNUZ for AMD GPUs.
**Returns:**
True if running on AMD GPU with CDNA3 architecture, False otherwise.
## `normalize_e4m3fn_to_e4m3fnuz()` {#max.nn.legacy.kernels.normalize_e4m3fn_to_e4m3fnuz}
> max.nn.legacy.kernels.normalize\_e4m3fn\_to\_e4m3fnuz(weight, weight\_scale)
Convert E4M3FN weights to E4M3FNUZ format for AMD GPUs.
This conversion is necessary because AMD GPUs use the E4M3FNUZ format
while NVIDIA GPUs use E4M3FN. The key differences are:
1. The bit pattern 10000000 (0x80) represents negative zero in E4M3FN but NaN in E4M3FNUZ
2. For the same bit representation, E4M3FNUZ values are half of E4M3FN values
**Parameters:**
* weight ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The weight tensor in E4M3FN format.
* weight\_scale ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The weight scale factor.
**Returns:**
Tuple of (converted\_weight, adjusted\_weight\_scale, adjusted\_input\_scale).
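The two format differences above imply a bit-level rewrite plus a scale adjustment. The sketch below illustrates the assumed semantics on raw uint8 storage (a hedged reference, not the MAX implementation, which operates on `TensorValue`s):

```python
import numpy as np

def normalize_e4m3fn_to_e4m3fnuz_ref(weight_bits: np.ndarray, weight_scale: float):
    """Bit-level sketch of the E4M3FN -> E4M3FNUZ conversion (assumed semantics)."""
    out = weight_bits.copy()
    # 0x80 encodes negative zero in E4M3FN but NaN in E4M3FNUZ, so remap it to +0.0.
    out[out == 0x80] = 0x00
    # For identical bit patterns, E4M3FNUZ values are half the E4M3FN values,
    # so the scale is doubled to preserve the dequantized result.
    return out, weight_scale * 2.0

bits = np.array([0x80, 0x3F], dtype=np.uint8)
converted, scale = normalize_e4m3fn_to_e4m3fnuz_ref(bits, 0.5)
```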
## `quantize_dynamic_block_scaled_fp4()` {#max.nn.legacy.kernels.quantize_dynamic_block_scaled_fp4}
> max.nn.legacy.kernels.quantize\_dynamic\_block\_scaled\_fp4(input, tensor\_sf, sf\_vector\_size=16, scales\_type=float8\_e4m3fn, out\_type=uint8)
Dynamically quantize the input tensor to fp4-e2m1fn.
**Parameters:**
* input ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The input tensor to quantize. Shape: [seq\_len, hidden\_size]
* tensor\_sf ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [float](https://docs.python.org/3/library/functions.html#float)) – The tensor-wise scale factor (inverted as per quantization kernel requirement).
* sf\_vector\_size ([int](https://docs.python.org/3/library/functions.html#int)) – The block size for the scaling factors.
* out\_type ([DType](../../dtype.md#max.dtype.DType)) – The type of the output tensor.
* scales\_type ([DType](../../dtype.md#max.dtype.DType)) – The type of the scales tensor.
**Returns:**
The quantized tensor in \[seq\_len, hidden\_size // 2] layout and the scales in \[ceildiv(seq\_len, 128), ceildiv(hidden\_size, sf\_vector\_size \* 4), 32, 4, 4] layout.
## `quantize_dynamic_scaled_float8()` {#max.nn.legacy.kernels.quantize_dynamic_scaled_float8}
> max.nn.legacy.kernels.quantize\_dynamic\_scaled\_float8(input, input\_scale\_spec, weight\_scale\_spec, scale\_ub=1200.0, group\_size\_or\_per\_token=-1, out\_type=float8\_e4m3fn, scales\_type=bfloat16)
Dynamically quantize the input tensor to fp8.
**Parameters:**
* input ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The input tensor to quantize.
* scale\_ub ([float](https://docs.python.org/3/library/functions.html#float)) – The upper bound of the scale factor.
* group\_size\_or\_per\_token ([int](https://docs.python.org/3/library/functions.html#int)) – The group size for quantization. When set to -1,
the quantization is column-wise.
* out\_type ([DType](../../dtype.md#max.dtype.DType)) – The type of the output tensor.
* scales\_type ([DType](../../dtype.md#max.dtype.DType)) – The type of the scales tensor.
* input\_scale\_spec ([Float8InputScaleSpec](float8_config.md#max.nn.legacy.float8_config.Float8InputScaleSpec))
* weight\_scale\_spec ([Float8WeightScaleSpec](float8_config.md#max.nn.legacy.float8_config.Float8WeightScaleSpec))
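The per-token path of dynamic fp8 quantization can be sketched as follows. This is a hedged reference under assumed semantics (row-wise scales derived from each row's max magnitude, clamped to `scale_ub`); the actual kernel's scale layout, clamping order, and final cast to `float8_e4m3fn` may differ:

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn

def quantize_per_token_ref(x: np.ndarray, scale_ub: float = 1200.0):
    """Sketch of per-token dynamic fp8 quantization (hypothetical helper)."""
    amax = np.abs(x).max(axis=-1, keepdims=True)
    scale = np.minimum(amax / FP8_E4M3_MAX, scale_ub)  # one scale per row
    scale = np.maximum(scale, 1e-12)                   # guard all-zero rows
    q = np.clip(x / scale, -FP8_E4M3_MAX, FP8_E4M3_MAX)
    return q, scale  # the real kernel would cast q to float8_e4m3fn here

x = np.array([[448.0, -224.0], [1.0, 0.5]])
q, s = quantize_per_token_ref(x)
```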
## `rms_norm_key_cache()` {#max.nn.legacy.kernels.rms_norm_key_cache}
> max.nn.legacy.kernels.rms\_norm\_key\_cache(kv\_params, kv\_collection, gamma, epsilon, layer\_idx, total\_seq\_len, input\_row\_offsets, weight\_offset, rms\_norm\_cols=None, multiply\_before\_cast=True, per\_head\_norm=True)
This function applies RMSNorm to the \_new\_ entries in the KVCache.
When per\_head\_norm=True (default), RMSNorm is applied separately to each head.
In this mode, gamma should have size \[head\_dim] and normalization occurs
across the head\_dim dimensions within each head.
When per\_head\_norm=False, RMSNorm is applied per token across all heads.
In this mode, gamma should have size \[n\_kv\_heads \* head\_dim] and normalization
occurs across all dimensions for each token.
The size of the gamma tensor determines how many dimensions will be normalized.
If gamma’s size doesn’t match the expected size based on per\_head\_norm setting,
rms\_norm\_cols must be explicitly specified to confirm the intention to normalize
only a subset of dimensions.
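The two gamma-sizing modes can be sketched with a plain RMSNorm over the last dimension. This is a simplified reference for the normalization math only (it ignores the cache layout, `weight_offset`, and `multiply_before_cast`):

```python
import numpy as np

def rms_norm_ref(x: np.ndarray, gamma: np.ndarray, epsilon: float = 1e-6):
    """RMSNorm over the last dimension (reference sketch)."""
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + epsilon)
    return (x / rms) * gamma

k = np.random.default_rng(0).standard_normal((4, 2, 8))  # [tokens, kv_heads, head_dim]
# per_head_norm=True: gamma has size [head_dim]; each head normalized separately.
per_head = rms_norm_ref(k, np.ones(8))
# per_head_norm=False: gamma has size [n_kv_heads * head_dim]; each token
# is normalized as one flat vector across all of its heads.
per_token = rms_norm_ref(k.reshape(4, 16), np.ones(16)).reshape(4, 2, 8)
```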
Currently, the KVCacheT class itself isn’t aware of new cache entries until the
cache length is incremented, which happens after the model’s forward pass, so
input\_row\_offsets is used to locate the new entries.
## `scatter_nd_skip_oob_indices()` {#max.nn.legacy.kernels.scatter_nd_skip_oob_indices}
> max.nn.legacy.kernels.scatter\_nd\_skip\_oob\_indices(input, updates, indices)
Creates a new symbolic tensor where the updates are scattered into input at specified indices.
This differs from scatter\_nd in that it handles oob indices by skipping
the update for that index. Oob indices are those which fall outside of
the range \[-dim, dim).
**Parameters:**
* input (Value\[TensorType] | [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../../graph/shape.md#max.graph.shape.Shape) | [Dim](../../graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../../driver.md#max.driver.DLPackArray)) – The input symbolic tensor to write elements to.
* updates (Value\[TensorType] | [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../../graph/shape.md#max.graph.shape.Shape) | [Dim](../../graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../../driver.md#max.driver.DLPackArray)) – A symbolic tensor of elements to write to input.
* indices (Value\[TensorType] | [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../../graph/shape.md#max.graph.shape.Shape) | [Dim](../../graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](../../driver.md#max.driver.DLPackArray)) – A tensor of indices specifying where to write updates.
Shape should be \[num\_updates, rank] for full indexing or
\[num\_updates, k] for partial indexing where k < rank.
**Returns:**
A new symbolic tensor representing the result of the scatter\_nd operation.
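The out-of-bounds-skipping behavior can be illustrated on a 1D tensor. This is a reference sketch of the semantics (full indexing, `indices` of shape \[num\_updates, 1]), not the symbolic-graph operation:

```python
import numpy as np

def scatter_nd_skip_oob_ref(data: np.ndarray, updates: np.ndarray,
                            indices: np.ndarray) -> np.ndarray:
    """Scatter that silently skips out-of-bounds indices (reference sketch)."""
    out = data.copy()
    dim = out.shape[0]
    for upd, (idx,) in zip(updates, indices):
        if -dim <= idx < dim:   # valid range is [-dim, dim)
            out[idx] = upd      # OOB indices are skipped, not an error
    return out

out = scatter_nd_skip_oob_ref(np.zeros(4), np.array([1.0, 2.0, 3.0]),
                              np.array([[1], [7], [-1]]))
# index 7 is out of [-4, 4), so its update is dropped; -1 wraps to the last slot.
```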
## `scatter_set_constant()` {#max.nn.legacy.kernels.scatter_set_constant}
> max.nn.legacy.kernels.scatter\_set\_constant(data, indices, fill\_val)
Scatters values into a tensor at specified indices.
## `sgmv_kernel()` {#max.nn.legacy.kernels.sgmv_kernel}
> max.nn.legacy.kernels.sgmv\_kernel(input, lora, lora\_ids, lora\_ranks, input\_row\_offsets, max\_lora\_seq\_len, lora\_end\_idx=None, bias=None)
Performs the SGMV kernel for LoRA. This kernel is LoRA-agnostic: it can
compute either the LoRA A or the LoRA B projection.
**Parameters:**
* input – The input tensor.
* lora – The LoRA tensor.
* lora\_ids – IDs of the LoRAs used for each sequence.
* lora\_ranks – The ranks of the LoRAs in the batch.
* input\_row\_offsets – The sequence offsets that use LoRA.
* max\_lora\_seq\_len – The maximum sequence length of any given LoRA in the batch.
* lora\_end\_idx – End index of LoRA tokens in the batch. Optional.
* bias – The LoRA bias.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – on input shapes/dtypes that are invalid for the kernel.
## `sgmv_lora_kernel()` {#max.nn.legacy.kernels.sgmv_lora_kernel}
> max.nn.legacy.kernels.sgmv\_lora\_kernel(input, lora\_a, lora\_b, lora\_ids, lora\_ranks, grouped\_row\_offsets, lora\_end\_idx, max\_lora\_seq\_len, bias=None)
Computes the SGMV LoRA kernel for some number of LoRAs A and B given the input.
out = Wx + xAB
SGMV can be explained by two independent kernels, where v is a zero-initialized
intermediate tensor and y is the output tensor:
* shrink – projects the high-dimensional input down to the low-rank space: v += xA
* expand – projects the low-rank intermediate back to the high-dimensional space: y += vB
**Parameters:**
* input ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The input tensor
* lora\_a ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The LoRA tensor for A
* lora\_b ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The LoRA tensor for B
* lora\_ids ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Ids of the LoRAs used for each sequence
* lora\_ranks ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The ranks of the LoRAs in the batch
* grouped\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The grouped sequence offsets that use LoRA
* max\_lora\_seq\_len ([int](https://docs.python.org/3/library/functions.html#int)) – The maximum sequence length of any given LoRA in the batch
* bias ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | None) – The LoRA bias
* lora\_end\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue))
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – on input shapes/dtypes that are invalid for the kernel.
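A dense NumPy sketch of the shrink/expand pair may help fix the layouts in mind. Tensor shapes here are assumptions for illustration (per-adapter dense weights, uniform rank, no ragged packing), not the kernel's actual memory layout:

```python
import numpy as np

def sgmv_lora_ref(x, lora_a, lora_b, lora_ids, row_offsets):
    """Dense sketch of SGMV shrink + expand (assumed layouts).

    x: [total_tokens, hidden]; lora_a: [G, hidden, rank];
    lora_b: [G, rank, hidden]; row_offsets: [num_groups + 1].
    """
    y = np.zeros_like(x)
    for g, lora_id in enumerate(lora_ids):
        start, end = row_offsets[g], row_offsets[g + 1]
        v = x[start:end] @ lora_a[lora_id]   # shrink: v += xA
        y[start:end] = v @ lora_b[lora_id]   # expand: y += vB
    return y  # the caller adds this to the base projection Wx

x = np.array([[3.0, 4.0]])
lora_a = np.array([[[1.0], [0.0]]])   # G=1, hidden=2, rank=1
lora_b = np.array([[[2.0, 0.0]]])     # G=1, rank=1, hidden=2
y = sgmv_lora_ref(x, lora_a, lora_b, lora_ids=[0], row_offsets=[0, 1])
```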
## `sgmv_lora_qkv_shrink()` {#max.nn.legacy.kernels.sgmv_lora_qkv_shrink}
> max.nn.legacy.kernels.sgmv\_lora\_qkv\_shrink(input, lora\_a, lora\_ids, lora\_grouped\_offsets, lora\_end\_idx, max\_lora\_seq\_len, max\_rank)
LoRA shrink grouped matmul with planar Q/K/V output.
Performs the LoRA ‘shrink’ operation for routed tokens using SGMV (segmented
grouped matrix-vector multiplication). Computes \[M, K] @ \[G, 3\*rank, K]^T
per active LoRA adapter, then permutes the flat \[M, 3\*rank] result into a
planar layout \[3, M, rank] representing separate Q, K, V projections.
**Parameters:**
* input ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Routed activation matrix with shape (M, K), where M is the total
number of tokens and K is the hidden dimension.
* lora\_a ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Shrink weights for all LoRA adapters, shape (G, 3\*rank, K) where
G is the number of adapters and rank is the LoRA rank.
* lora\_ids ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Expert/adapter indices for each active group, shape (num\_active,).
Values in range \[0, G). May use -1 to indicate inactive slots.
* lora\_grouped\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Inclusive prefix sums of tokens per active adapter,
shape (num\_active + 1,). Defines per-adapter \[start, end) ranges in
input. Must be non-decreasing with offsets\[0] == 0.
* max\_lora\_seq\_len ([int](https://docs.python.org/3/library/functions.html#int)) – Upper bound on tokens for any active adapter. Used for
kernel tuning and memory allocation.
* max\_rank ([int](https://docs.python.org/3/library/functions.html#int)) – The maximum LoRA rank, determines output shape.
* lora\_end\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue))
**Returns:**
Output tensor with planar Q/K/V layout, shape (3, M, max\_rank).
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – on input shapes/dtypes that are invalid for the kernel.
## `sgmv_qkv_lora_kernel()` {#max.nn.legacy.kernels.sgmv_qkv_lora_kernel}
> max.nn.legacy.kernels.sgmv\_qkv\_lora\_kernel(input, lora\_a, lora\_b\_q, lora\_b\_kv, lora\_ids, lora\_ranks, input\_row\_offsets, lora\_grouped\_offsets, lora\_end\_idx, batch\_seq\_len, lora\_ids\_kv, lora\_grouped\_offsets\_kv, kv\_collection, kv\_params, layer\_idx, max\_lora\_seq\_len, max\_rank, bias=None)
Computes the SGMV QKV LoRA kernel for Q, K, V projections with LoRA.
**Parameters:**
* input ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The input tensor.
* lora\_a ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The LoRA A tensor.
* lora\_b\_q ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The LoRA B tensor for Q projection.
* lora\_b\_kv ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The LoRA B tensor for K and V projections (stacked).
* lora\_ids ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – IDs of the LoRAs used for each sequence.
* lora\_ranks ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The ranks of the LoRAs in the batch.
* input\_row\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The sequence offsets that use LoRA.
* lora\_grouped\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Grouped offsets for LoRA sequences.
* lora\_end\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – End index of LoRA tokens in the batch.
* batch\_seq\_len ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Total sequence length of the batch.
* lora\_ids\_kv ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – LoRA IDs for KV projections (with offset for V portion).
* lora\_grouped\_offsets\_kv ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Grouped offsets for KV LoRA sequences.
* kv\_collection (PagedCacheValues) – The KV cache.
* kv\_params ([KVCacheParams](kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParams)) – The KV params.
* layer\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The layer index to retrieve the KV cache.
* max\_lora\_seq\_len ([int](https://docs.python.org/3/library/functions.html#int)) – The maximum sequence length of any given LoRA in the batch.
* max\_rank ([int](https://docs.python.org/3/library/functions.html#int)) – The maximum rank for the LoRAs.
* bias ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | None) – Optional LoRA bias.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – on input shapes/dtypes that are invalid for the kernel.
## `sleep()` {#max.nn.legacy.kernels.sleep}
> max.nn.legacy.kernels.sleep(duration\_sec, device\_ref)
Sleep for the given duration in seconds.
This kernel is supported on CPUs and GPUs. However, the timing may be
inaccurate on AMD GPUs due to limitations of the current time.sleep(…) implementation.
**Parameters:**
* duration\_sec ([BufferValue](../../graph/BufferValue.md#max.graph.BufferValue)) – The duration to sleep in seconds.
* device\_ref ([DeviceRef](../../graph/type.md#max.graph.type.DeviceRef))
**Return type:**
None
## `sliced_add()` {#max.nn.legacy.kernels.sliced_add}
> max.nn.legacy.kernels.sliced\_add(x, y, lora\_end\_idx)
Adds tensors x and y element-wise for rows < lora\_end\_idx, otherwise copies x.
This is used for LoRA where only some sequences have LoRA applied.
For rows in \[0, lora\_end\_idx): c = x + y
For rows in \[lora\_end\_idx, batch\_seq\_len): c = x
**Parameters:**
* x ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – First input tensor.
* y ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – Second input tensor.
* lora\_end\_idx ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – End index of LoRA token portion (rows to apply add).
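The two row ranges described above can be sketched directly:

```python
import numpy as np

def sliced_add_ref(x: np.ndarray, y: np.ndarray, lora_end_idx: int) -> np.ndarray:
    """Reference semantics for sliced_add."""
    out = x.copy()
    out[:lora_end_idx] += y[:lora_end_idx]  # rows with LoRA applied: c = x + y
    return out                              # remaining rows pass through: c = x

out = sliced_add_ref(np.ones((3, 2)), np.full((3, 2), 2.0), lora_end_idx=2)
```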
## `spatial_merge()` {#max.nn.legacy.kernels.spatial_merge}
> max.nn.legacy.kernels.spatial\_merge(input, grid\_thw, hidden\_size, merge\_size)
Performs spatial merge operation on ragged input tensors.
This operation merges spatial dimensions of input patches according to
the grid dimensions specified in grid\_thw.
## `unfused_qkv_ragged_matmul_gguf_quantized()` {#max.nn.legacy.kernels.unfused_qkv_ragged_matmul_gguf_quantized}
> max.nn.legacy.kernels.unfused\_qkv\_ragged\_matmul\_gguf\_quantized(kv\_params, input, input\_row\_offsets, n\_heads, q\_weight, k\_weight, v\_weight, quantization\_encoding\_q, quantization\_encoding\_k, quantization\_encoding\_v, kv\_collection, layer\_idx)
Computes fused query, key, and value projections with ragged input and
quantized weight matrices. A quantization\_config must be provided.
input and input\_row\_offsets are used together to implement the ragged
tensor: input\_row\_offsets indicates where each sequence starts and ends in
input.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – on input shapes/dtypes that are invalid for the kernel.
## `update_frequency_data()` {#max.nn.legacy.kernels.update_frequency_data}
> max.nn.legacy.kernels.update\_frequency\_data(frequency\_data, frequency\_offsets, tokens)
Updates the frequency data.
**Parameters:**
* frequency\_data ([BufferValue](../../graph/BufferValue.md#max.graph.BufferValue)) – 2d tensor of shape \[unique\_tokens, 2], where
the first column indicates the token id and the second column
indicates the frequency of the token.
* frequency\_offsets ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – 1d tensor of shape \[batch\_size + 1], indicating
start of each sequence’s data.
* tokens ([TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)) – The tokens to update the frequency data with.
**Return type:**
None
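A sketch of the in-place update under assumed semantics (one new token per sequence, matched against the token-id column; how previously unseen token IDs enter the table is not shown):

```python
import numpy as np

def update_frequency_data_ref(frequency_data, frequency_offsets, tokens):
    """Bump per-sequence token counts in a [unique_tokens, 2] table, in place."""
    for seq in range(len(frequency_offsets) - 1):
        start, end = frequency_offsets[seq], frequency_offsets[seq + 1]
        rows = frequency_data[start:end]              # this sequence's rows
        hit = np.nonzero(rows[:, 0] == tokens[seq])[0]
        if hit.size:                                  # token already tracked
            rows[hit[0], 1] += 1

freq = np.array([[5, 0], [7, 2]])   # (token_id, frequency) rows
update_frequency_data_ref(freq, np.array([0, 2]), np.array([7]))
```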
---
## cache_params
## `KVCacheParamInterface` {#max.nn.legacy.kv_cache.cache_params.KVCacheParamInterface}
> class max.nn.legacy.kv\_cache.cache\_params.KVCacheParamInterface(\*args, \*\*kwargs)
Interface for KV cache parameters.
### `bytes_per_block` {#max.nn.legacy.kv_cache.cache_params.KVCacheParamInterface.bytes_per_block}
> property bytes\_per\_block: [int](https://docs.python.org/3/library/functions.html#int)
Number of bytes per cache block.
### `cache_strategy` {#max.nn.legacy.kv_cache.cache_params.KVCacheParamInterface.cache_strategy}
> cache\_strategy: [KVCacheStrategy](#max.nn.legacy.kv_cache.cache_params.KVCacheStrategy)
### `data_parallel_degree` {#max.nn.legacy.kv_cache.cache_params.KVCacheParamInterface.data_parallel_degree}
> data\_parallel\_degree: [int](https://docs.python.org/3/library/functions.html#int)
### `get_symbolic_inputs()` {#max.nn.legacy.kv_cache.cache_params.KVCacheParamInterface.get_symbolic_inputs}
> get\_symbolic\_inputs()
Returns the symbolic inputs for the KV cache.
**Return type:**
InputSymbolInterface
### `n_devices` {#max.nn.legacy.kv_cache.cache_params.KVCacheParamInterface.n_devices}
> n\_devices: [int](https://docs.python.org/3/library/functions.html#int)
### `page_size` {#max.nn.legacy.kv_cache.cache_params.KVCacheParamInterface.page_size}
> page\_size: [int](https://docs.python.org/3/library/functions.html#int)
## `KVCacheParams` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams}
> class max.nn.legacy.kv\_cache.cache\_params.KVCacheParams(dtype, n\_kv\_heads, head\_dim, num\_layers, devices, enable\_prefix\_caching=False, enable\_kvcache\_swapping\_to\_host=False, host\_kvcache\_swap\_space\_gb=None, cache\_strategy=KVCacheStrategy.PAGED, page\_size=128, is\_mla=False, data\_parallel\_degree=1, n\_kv\_heads\_per\_device=0, kvcache\_quant\_config=None)
Configuration parameters for key-value cache management in transformer models.
This class encapsulates all configuration options for managing KV caches during
inference, including parallelism settings, memory management, and cache strategy.
### `bytes_per_block` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.bytes_per_block}
> property bytes\_per\_block: [int](https://docs.python.org/3/library/functions.html#int)
Returns the number of bytes per cache block.
When TP>1, each block is sharded across the devices in the tensor parallel group.
This property returns the total memory needed to store a block across these
devices, including memory for scales if quantization is enabled.
**Returns:**
The number of bytes per cache block.
### `cache_strategy` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.cache_strategy}
> cache\_strategy: [KVCacheStrategy](#max.nn.legacy.kv_cache.cache_params.KVCacheStrategy) = 'paged'
Strategy to use for managing the KV cache.
### `compute_num_host_blocks()` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.compute_num_host_blocks}
> compute\_num\_host\_blocks()
Computes the number of blocks that can be allocated to the host.
**Returns:**
The number of blocks that can be allocated to the host.
### `copy_as_dp_1()` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.copy_as_dp_1}
> copy\_as\_dp\_1()
Creates a copy of the KVCacheParams with data parallelism disabled.
This method creates a new instance of the current configuration and adjusts
the device count to reflect a tensor-parallel-only setup (data\_parallel\_degree=1).
The number of devices is divided by the current data parallel degree.
**Returns:**
A new KVCacheParams instance with data\_parallel\_degree set to 1.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If n\_devices is not evenly divisible by data\_parallel\_degree.
### `data_parallel_degree` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.data_parallel_degree}
> data\_parallel\_degree: [int](https://docs.python.org/3/library/functions.html#int) = 1
Degree of data parallelism. Must be 1 or equal to n\_devices (DP+TP not yet supported).
### `devices` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.devices}
> devices: [Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[DeviceRef](../../../graph/type.md#max.graph.type.DeviceRef)]
Devices to use for the KV cache.
### `dtype` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.dtype}
> dtype: [DType](../../../dtype.md#max.dtype.DType)
Data type for storing key and value tensors in the cache.
### `dtype_shorthand` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.dtype_shorthand}
> property dtype\_shorthand: [str](https://docs.python.org/3/library/stdtypes.html#str)
Returns a shorthand textual representation of the data type.
**Returns:**
“bf16” for bfloat16 dtype, “f32” otherwise.
### `enable_kvcache_swapping_to_host` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.enable_kvcache_swapping_to_host}
> enable\_kvcache\_swapping\_to\_host: [bool](https://docs.python.org/3/library/functions.html#bool) = False
Whether to enable swapping of KV cache blocks to host memory when device memory is full.
### `enable_prefix_caching` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.enable_prefix_caching}
> enable\_prefix\_caching: [bool](https://docs.python.org/3/library/functions.html#bool) = False
Whether to enable prefix caching for efficient reuse of common prompt prefixes.
### `get_symbolic_inputs()` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.get_symbolic_inputs}
> get\_symbolic\_inputs()
Computes the symbolic inputs for the KV cache.
This method returns a list of PagedCacheInputSymbols for each replica.
This is used when constructing the model graph.
**Returns:**
The symbolic inputs for the KV cache.
**Return type:**
PagedCacheInputSymbolsByReplica
### `head_dim` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.head_dim}
> head\_dim: [int](https://docs.python.org/3/library/functions.html#int)
Dimensionality of each attention head.
### `host_kvcache_swap_space_gb` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.host_kvcache_swap_space_gb}
> host\_kvcache\_swap\_space\_gb: [float](https://docs.python.org/3/library/functions.html#float) | [None](https://docs.python.org/3/library/constants.html#None) = None
Amount of host memory (in GB) to reserve for KV cache swapping. Required when swapping is enabled.
### `is_mla` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.is_mla}
> is\_mla: [bool](https://docs.python.org/3/library/functions.html#bool) = False
Whether the model uses Multi-Latent Attention (MLA) architecture.
### `kvcache_quant_config` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.kvcache_quant_config}
> kvcache\_quant\_config: [KVCacheQuantizationConfig](#max.nn.legacy.kv_cache.cache_params.KVCacheQuantizationConfig) | [None](https://docs.python.org/3/library/constants.html#None) = None
KVCache quantization config. Currently only FP8 quantization supported.
### `n_devices` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.n_devices}
> property n\_devices: [int](https://docs.python.org/3/library/functions.html#int)
Returns the number of devices.
**Returns:**
The number of devices.
### `n_kv_heads` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.n_kv_heads}
> n\_kv\_heads: [int](https://docs.python.org/3/library/functions.html#int)
Total number of key-value attention heads across all devices.
### `n_kv_heads_per_device` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.n_kv_heads_per_device}
> n\_kv\_heads\_per\_device: [int](https://docs.python.org/3/library/functions.html#int) = 0
Number of KV heads allocated to each device. Computed automatically in \_\_post\_init\_\_.
### `num_layers` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.num_layers}
> num\_layers: [int](https://docs.python.org/3/library/functions.html#int)
Number of layers in the model.
### `page_size` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.page_size}
> page\_size: [int](https://docs.python.org/3/library/functions.html#int) = 128
Number of tokens per page (block) when using the paged cache strategy.
This value is expressed in tokens, not bytes. The byte footprint of a page is
derived from pipeline configuration.
Current constraints: the page size must be a multiple of 128 and at least 128.
Required when `cache_strategy` is `KVCacheStrategy.PAGED`.
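For a back-of-envelope sense of how page size translates to bytes, the sketch below assumes a plain layout of K and V planes per layer (it is not the `bytes_per_block` implementation, which also accounts for sharding details and quantization scales):

```python
def bytes_per_block_estimate(num_layers: int, n_kv_heads_per_device: int,
                             head_dim: int, page_size: int = 128,
                             dtype_bytes: int = 2, n_devices: int = 1) -> int:
    """Rough KV block size: K and V planes for every layer (assumed layout)."""
    per_device = (2 * num_layers * n_kv_heads_per_device * head_dim
                  * page_size * dtype_bytes)
    return per_device * n_devices  # a block spans the tensor-parallel group

# A Llama-3-8B-like config: 32 layers, 8 KV heads/device, head_dim 128, bf16.
size = bytes_per_block_estimate(32, 8, 128, page_size=128, dtype_bytes=2)
```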
### `quantized_kv_cache` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.quantized_kv_cache}
> property quantized\_kv\_cache: [bool](https://docs.python.org/3/library/functions.html#bool)
### `shape_per_block` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.shape_per_block}
> property shape\_per\_block: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]
Returns the shape of each cache block.
**Returns:**
The shape of the cache block.
### `shape_per_scale_block` {#max.nn.legacy.kv_cache.cache_params.KVCacheParams.shape_per_scale_block}
> property shape\_per\_scale\_block: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]
Returns the shape of each scale block used for KVCache quantization.
**Returns:**
The shape of the KVCache quantization scales block.
## `KVCacheQuantizationConfig` {#max.nn.legacy.kv_cache.cache_params.KVCacheQuantizationConfig}
> class max.nn.legacy.kv\_cache.cache\_params.KVCacheQuantizationConfig(scale\_dtype=float32, quantization\_granularity=128)
Configuration for KVCache quantization.
Currently only FP8 Quantization is supported.
### `quantization_granularity` {#max.nn.legacy.kv_cache.cache_params.KVCacheQuantizationConfig.quantization_granularity}
> quantization\_granularity: [int](https://docs.python.org/3/library/functions.html#int) = 128
Block size used for KVCache quantization along the head dimension (e.g. 128).
### `scale_dtype` {#max.nn.legacy.kv_cache.cache_params.KVCacheQuantizationConfig.scale_dtype}
> scale\_dtype: [DType](../../../dtype.md#max.dtype.DType) = 81
Data type of the quantization scales, if quantization is enabled.
## `KVCacheStrategy` {#max.nn.legacy.kv_cache.cache_params.KVCacheStrategy}
> class max.nn.legacy.kv\_cache.cache\_params.KVCacheStrategy(value, names=\<not given>, \*values, module=None, qualname=None, type=None, start=1, boundary=None)
Enumeration of supported KV cache strategies for attention mechanisms.
This enum defines the different strategies for managing key-value caches
in transformer models during inference.
### `MODEL_DEFAULT` {#max.nn.legacy.kv_cache.cache_params.KVCacheStrategy.MODEL_DEFAULT}
> MODEL\_DEFAULT = 'model\_default'
Use the model’s default caching strategy.
### `PAGED` {#max.nn.legacy.kv_cache.cache_params.KVCacheStrategy.PAGED}
> PAGED = 'paged'
Use paged attention for efficient memory management.
### `kernel_substring()` {#max.nn.legacy.kv_cache.cache_params.KVCacheStrategy.kernel_substring}
> kernel\_substring()
Returns the common substring included in the kernel name for this caching strategy.
**Returns:**
The string representation of the cache strategy value.
## `MultiKVCacheParams` {#max.nn.legacy.kv_cache.cache_params.MultiKVCacheParams}
> class max.nn.legacy.kv\_cache.cache\_params.MultiKVCacheParams(params, cache\_strategy, page\_size, data\_parallel\_degree, n\_devices)
Aggregates multiple KV cache parameter sets.
This class implements KVCacheParamInterface by aggregating multiple
KVCacheParamInterface instances. Useful for models with multiple distinct
KV caches (e.g., different cache configurations for different layers).
### `bytes_per_block` {#max.nn.legacy.kv_cache.cache_params.MultiKVCacheParams.bytes_per_block}
> property bytes\_per\_block: [int](https://docs.python.org/3/library/functions.html#int)
Total bytes per block across all KV caches.
Since all caches allocate memory for the same sequence, the total
memory cost per block is the sum across all param sets.
### `cache_strategy` {#max.nn.legacy.kv_cache.cache_params.MultiKVCacheParams.cache_strategy}
> cache\_strategy: [KVCacheStrategy](#max.nn.legacy.kv_cache.cache_params.KVCacheStrategy)
### `data_parallel_degree` {#max.nn.legacy.kv_cache.cache_params.MultiKVCacheParams.data_parallel_degree}
> data\_parallel\_degree: [int](https://docs.python.org/3/library/functions.html#int)
### `from_params()` {#max.nn.legacy.kv_cache.cache_params.MultiKVCacheParams.from_params}
> classmethod from\_params(\*params)
### `get_symbolic_inputs()` {#max.nn.legacy.kv_cache.cache_params.MultiKVCacheParams.get_symbolic_inputs}
> get\_symbolic\_inputs()
Returns the symbolic inputs for the KV cache.
**Return type:**
MultiKVCacheInputSymbols
### `n_devices` {#max.nn.legacy.kv_cache.cache_params.MultiKVCacheParams.n_devices}
> n\_devices: [int](https://docs.python.org/3/library/functions.html#int)
### `page_size` {#max.nn.legacy.kv_cache.cache_params.MultiKVCacheParams.page_size}
> page\_size: [int](https://docs.python.org/3/library/functions.html#int)
### `params` {#max.nn.legacy.kv_cache.cache_params.MultiKVCacheParams.params}
> params: [Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[KVCacheParamInterface](#max.nn.legacy.kv_cache.cache_params.KVCacheParamInterface)]
List of KV cache parameter sets to aggregate.
## `compute_max_seq_len_fitting_in_cache()` {#max.nn.legacy.kv_cache.cache_params.compute_max_seq_len_fitting_in_cache}
> max.nn.legacy.kv\_cache.cache\_params.compute\_max\_seq\_len\_fitting\_in\_cache(params, available\_cache\_memory)
Computes the maximum sequence length that can fit in the available memory.
**Parameters:**
* available\_cache\_memory ([int](https://docs.python.org/3/library/functions.html#int)) – The amount of cache memory available across all devices.
* params ([KVCacheParamInterface](#max.nn.legacy.kv_cache.cache_params.KVCacheParamInterface))
**Returns:**
The maximum sequence length that can fit in the available cache memory.
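The relationship between the memory budget, block size, and token capacity can be sketched as follows. This is a simplification: the actual byte footprint of a block is derived from pipeline configuration, and the function and parameter names here are illustrative:

```python
def max_seq_len_fitting_in_cache(
    available_cache_memory: int,  # bytes, across all devices
    bytes_per_block: int,         # byte footprint of one page/block
    page_size: int,               # tokens per block
) -> int:
    # Only whole blocks fit in memory; each block holds page_size tokens.
    num_blocks = available_cache_memory // bytes_per_block
    return num_blocks * page_size

# e.g. 1 GiB of cache with a 64 KiB footprint per 128-token block:
print(max_seq_len_fitting_in_cache(1 << 30, 64 << 10, 128))  # 2097152
```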
## `compute_num_device_blocks()` {#max.nn.legacy.kv_cache.cache_params.compute_num_device_blocks}
> max.nn.legacy.kv\_cache.cache\_params.compute\_num\_device\_blocks(params, available\_cache\_memory, max\_batch\_size, max\_seq\_len)
Computes the number of blocks that can be allocated based on the available cache memory.
The number of blocks returned is for a single replica. Each replica will
have the same number of blocks.
**Parameters:**
* available\_cache\_memory ([int](https://docs.python.org/3/library/functions.html#int)) – The amount of cache memory available across all devices.
* max\_batch\_size ([int](https://docs.python.org/3/library/functions.html#int) | None) – The maximum batch size, or None.
* max\_seq\_len ([int](https://docs.python.org/3/library/functions.html#int) | None) – The maximum sequence length, or None.
* params ([KVCacheParamInterface](#max.nn.legacy.kv_cache.cache_params.KVCacheParamInterface))
**Returns:**
The number of blocks that can be allocated for a single replica.
## `estimated_memory_size()` {#max.nn.legacy.kv_cache.cache_params.estimated_memory_size}
> max.nn.legacy.kv\_cache.cache\_params.estimated\_memory\_size(params, available\_cache\_memory, max\_batch\_size, max\_seq\_len)
Computes the estimated memory size of the KV cache used by all replicas.
**Parameters:**
* available\_cache\_memory ([int](https://docs.python.org/3/library/functions.html#int)) – The amount of cache memory available across all devices.
* max\_batch\_size ([int](https://docs.python.org/3/library/functions.html#int)) – The maximum batch size.
* max\_seq\_len ([int](https://docs.python.org/3/library/functions.html#int)) – The maximum sequence length.
* params ([KVCacheParamInterface](#max.nn.legacy.kv_cache.cache_params.KVCacheParamInterface))
**Returns:**
The estimated memory usage of the KV cache in bytes.
---
## kv_cache
Legacy key-value cache management for efficient attention computation.
## Modules
* [`cache_params`](/max/api/python/nn/legacy/kv_cache/cache_params): Configuration parameters for KV cache.
* [`manager`](/max/api/python/nn/legacy/kv_cache/manager): KV cache manager implementation.
---
## manager
---
## layer
## `Layer` {#max.nn.legacy.layer.Layer}
> class max.nn.legacy.layer.Layer
:::caution Deprecated
Deprecated since version 25.2.
:::
Base class for neural network components.
Use [`Module`](#max.nn.legacy.layer.Module) instead.
Provides functionality for adding hooks to the call function of
each layer to support testing, debugging or profiling.
## `LayerList` {#max.nn.legacy.layer.LayerList}
> class max.nn.legacy.layer.LayerList(layers)
Stores a list of layers.
Can be used as a regular Python list.
### `sublayers` {#max.nn.legacy.layer.LayerList.sublayers}
> property sublayers: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Module](#max.nn.legacy.layer.Module)]
## `Module` {#max.nn.legacy.layer.Module}
> class max.nn.legacy.layer.Module
Base class for model components with weight management.
Provides functionality to create custom layers and construct networks with automatic weight tracking.
The following example uses the [`Module`](#max.nn.legacy.layer.Module) class to create custom layers and build a neural network:
```python
from max import nn
from max.dtype import DType
from max.graph import Weight, ops, DeviceRef

class Linear(nn.Module):
    def __init__(self, in_dims, out_dims):
        super().__init__()
        self.weight = Weight("weight", DType.float32, (in_dims, out_dims), DeviceRef.CPU())

    def __call__(self, x):
        return x @ self.weight

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.up = Linear(5, 10)
        self.gate = Linear(5, 10)
        self.down = Linear(10, 5)

    def __call__(self, x):
        return self.down(ops.silu(self.gate(x)) + self.up(x))

model = MLP()
print(model.state_dict())  # {"up.weight": Buffer([5, 10]), ...}
```
Constructing a graph without [`Module`](#max.nn.legacy.layer.Module) can result in name collisions
with the weights (in this example, there would be three weights with the
name `weight`). With [`Module`](#max.nn.legacy.layer.Module), you can use [`state_dict()`](#max.nn.legacy.layer.Module.state_dict) or
[`load_state_dict()`](#max.nn.legacy.layer.Module.load_state_dict) to initialize or set the weights values, and finalize
the weight names to be unique within the model.
### `build_subgraph()` {#max.nn.legacy.layer.Module.build_subgraph}
> build\_subgraph(name, input\_types, weight\_prefix='')
Builds a subgraph for this module.
This method creates a subgraph that encapsulates the module’s logic,
handling input types, weights, and creating a graph with the module’s
computation.
Once the subgraph is built, it can be called using the `ops.call`
op.
**Parameters:**
* name ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The name of the subgraph to create.
* input\_types ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[Type](../../graph/type.md#max.graph.type.Type)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [list](https://docs.python.org/3/library/stdtypes.html#list)\[[Type](../../graph/type.md#max.graph.type.Type)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]]) – A list of input types for the subgraph. Each element can be
either a single `Type` or a list of `Type` objects.
* weight\_prefix ([str](https://docs.python.org/3/library/stdtypes.html#str)) – Optional prefix for weight names in the subgraph. If provided,
weights with names starting with this prefix will have their names
modified by removing the prefix and will be marked as placeholders.
**Returns:**
The created subgraph containing the module’s computation.
**Return type:**
`Graph`
:::note Note
* Placeholder weights will require the `prefix` attribute of `ops.call` to be set.
:::
### `layer_weights` {#max.nn.legacy.layer.Module.layer_weights}
> property layer\_weights: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Weight](../../graph/Weight.md#max.graph.Weight)]
### `load_state_dict()` {#max.nn.legacy.layer.Module.load_state_dict}
> load\_state\_dict(state\_dict, \*, override\_quantization\_encoding=False, weight\_alignment=None, strict=True)
Sets the values of all weights in this model.
**Parameters:**
* state\_dict ([Mapping](https://docs.python.org/3/library/collections.abc.html#collections.abc.Mapping)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [DLPackArray](../../driver.md#max.driver.DLPackArray) | [WeightData](../../graph/weights.md#max.graph.weights.WeightData)]) – A map from weight name to a numpy array or
[`max.driver.Buffer`](../../driver.md#max.driver.Buffer).
* override\_quantization\_encoding ([bool](https://docs.python.org/3/library/functions.html#bool)) – Whether to override the weight
quantization based on the loaded value.
* weight\_alignment ([int](https://docs.python.org/3/library/functions.html#int) | None) – If specified, overrides the alignment for each
weight in the Module. If left as None, each value in
state\_dict must be aligned by the default dtype alignment.
* strict ([bool](https://docs.python.org/3/library/functions.html#bool)) – If True, raises an error if any weights required by the
Module are missing from state\_dict, or if any keys in
state\_dict were not used by the Module. If False, both
missing and unexpected keys are tolerated and reported only
via return values/logging by callers.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If strict is True and any required weight is missing
from state\_dict, or if state\_dict contains keys not used by
the Module.
**Return type:**
None
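The strict-mode key checking follows the usual state-dict convention; a pure-Python sketch of the set logic (`check_strict` is illustrative, not the MAX implementation):

```python
def check_strict(required: set[str], provided: set[str]) -> None:
    # strict=True semantics: every required weight must be provided,
    # and no unused keys may remain in the state dict.
    missing = required - provided
    unexpected = provided - required
    if missing or unexpected:
        raise ValueError(
            f"missing keys: {sorted(missing)}, "
            f"unexpected keys: {sorted(unexpected)}"
        )

check_strict({"up.weight", "down.weight"}, {"up.weight", "down.weight"})  # ok
```

With `strict=False`, both missing and unexpected keys would be tolerated rather than raising.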
### `raw_state_dict()` {#max.nn.legacy.layer.Module.raw_state_dict}
> raw\_state\_dict()
Returns all weights objects in the model.
Unlike [`state_dict`](#max.nn.legacy.layer.Module.state_dict), this returns [`max.graph.Weight`](../../graph/Weight.md#max.graph.Weight) objects instead of
the assigned values. Some parameters inside the `Weight` can be
configured before a graph is built. Do not change these attributes after
building a graph:
* [`align`](../../graph/Weight.md#max.graph.Weight.align)
* [`dtype`](../../graph/Weight.md#max.graph.Weight.dtype)
* [`quantization_encoding`](../../graph/Weight.md#max.graph.Weight.quantization_encoding)
* [`shape`](../../graph/Weight.md#max.graph.Weight.shape)
**Returns:**
Map from weight name to the [`max.graph.Weight`](../../graph/Weight.md#max.graph.Weight) object.
### `state_dict()` {#max.nn.legacy.layer.Module.state_dict}
> state\_dict(auto\_initialize=True)
Returns values of all weights in the model.
The values returned are the same as the values set in [`load_state_dict`](#max.nn.legacy.layer.Module.load_state_dict).
If [`load_state_dict`](#max.nn.legacy.layer.Module.load_state_dict) has not been called and none of the weights have
values, then they are initialized to zero.
**Parameters:**
auto\_initialize ([bool](https://docs.python.org/3/library/functions.html#bool)) – Determines whether to initialize weights to zero if
the weight value has not been loaded. If this is False, a
ValueError is raised if an uninitialized weight is found.
**Returns:**
Map from weight name to the weight value (can be numpy array or
[`max.driver.Buffer`](../../driver.md#max.driver.Buffer)).
### `sublayers` {#max.nn.legacy.layer.Module.sublayers}
> property sublayers: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Module](#max.nn.legacy.layer.Module)]
## `Shardable` {#max.nn.legacy.layer.Shardable}
> class max.nn.legacy.layer.Shardable(\*args, \*\*kwargs)
Protocol for objects that support sharding across multiple devices.
This protocol defines the interface that all shardable components
(like Linear layers and Weight objects) must implement to participate
in distributed computation.
### `shard()` {#max.nn.legacy.layer.Shardable.shard}
> shard(devices)
Creates a sharded view of this object for a specific device.
**Parameters:**
devices ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[DeviceRef](../../graph/type.md#max.graph.type.DeviceRef)]) – The devices where the shards should reside.
### `sharding_strategy` {#max.nn.legacy.layer.Shardable.sharding_strategy}
> property sharding\_strategy: ShardingStrategy | [None](https://docs.python.org/3/library/constants.html#None)
Gets the weight sharding strategy.
## `add_layer_hook()` {#max.nn.legacy.layer.add_layer_hook}
> max.nn.legacy.layer.add\_layer\_hook(fn)
Adds a hook to call a function after each layer’s `__call__`.
The function will be passed four inputs:
* layer
* input\_args
* input\_kwargs
* outputs
The function can either return None or new outputs that will replace the
layer's returned outputs.
Note that inputs and outputs contain graph Values, which show limited
information (like [`shape`](../../graph/TensorValue.md#max.graph.TensorValue.shape) and [`dtype`](../../graph/TensorValue.md#max.graph.TensorValue.dtype)). You can still see the computed values
if you include the Value in the `graph.ops.output` op, or call `graph.ops.print`.
Example of printing debug inputs:
```python
def print_info(layer, args, kwargs, outputs):
print("Layer:", type(layer).__name__)
print("Input args:", args)
print("Input kwargs:", kwargs)
print("Outputs:", outputs)
return outputs
add_layer_hook(print_info)
```
## `clear_hooks()` {#max.nn.legacy.layer.clear_hooks}
> max.nn.legacy.layer.clear\_hooks()
Remove all hooks.
**Return type:**
None
## `recursive_named_layers()` {#max.nn.legacy.layer.recursive_named_layers}
> max.nn.legacy.layer.recursive\_named\_layers(parent, prefix='')
Recursively walks through the layers and generates names.
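The traversal can be pictured with a minimal stand-in for `Module` (the `Node` class below is hypothetical; only the dotted-name scheme mirrors this function):

```python
from typing import Iterator

class Node:
    # Minimal stand-in for a Module with a `sublayers` mapping.
    def __init__(self, **sublayers: "Node") -> None:
        self.sublayers = sublayers

def recursive_named_layers(parent: Node, prefix: str = "") -> Iterator[tuple[str, Node]]:
    # Yield (dotted_name, layer) pairs, depth-first, parent before children.
    yield prefix, parent
    for name, child in parent.sublayers.items():
        child_prefix = f"{prefix}.{name}" if prefix else name
        yield from recursive_named_layers(child, child_prefix)

model = Node(up=Node(), gate=Node(), down=Node(proj=Node()))
print([name for name, _ in recursive_named_layers(model)])
# ['', 'up', 'gate', 'down', 'down.proj']
```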
## `DistributedGemmConfig` {#max.nn.legacy.linear.DistributedGemmConfig}
> class max.nn.legacy.linear.DistributedGemmConfig(enable\_matmul\_allreduce)
Configuration for how distributed General Matrix Multiply (GEMM) operations
are executed.
## `Linear` {#max.nn.legacy.linear.Linear}
> class max.nn.legacy.linear.Linear(in\_dim, out\_dim, dtype, device, has\_bias=False, quantization\_encoding=None, float8\_config=None, name=None, clip\_weight=None, is\_sharding=False)
Applies a linear transformation to incoming data: $y = xW^T + b$.
This layer implements a fully connected layer where inputs are multiplied
by a weight matrix and optionally added with a bias vector.
Both weights and bias initially reside on CPU, and the model init phase
moves them to the specified device.
Example:
```python
linear_layer = Linear(
in_dim=256,
out_dim=128,
dtype=DType.float32,
device=DeviceRef.GPU(),
name="linear",
has_bias=True
)
input_tensor: TensorValue
output = linear_layer(input_tensor)
```
### `bias` {#max.nn.legacy.linear.Linear.bias}
> bias: [Weight](../../graph/Weight.md#max.graph.Weight) | [None](https://docs.python.org/3/library/constants.html#None) = None
The optional bias vector stored on CPU with shape (out\_dim,).
Model init moves the bias to the target device if present.
### `device` {#max.nn.legacy.linear.Linear.device}
> device: [DeviceRef](../../graph/ops.md#max.graph.ops.DeviceRef)
The device where matrix operations are performed.
### `input_scale` {#max.nn.legacy.linear.Linear.input_scale}
> input\_scale: [Weight](../../graph/Weight.md#max.graph.Weight) | [None](https://docs.python.org/3/library/constants.html#None) = None
The optional input scale stored on CPU with shape ().
Model init moves the input\_scale to the target device if present.
### `shard()` {#max.nn.legacy.linear.Linear.shard}
> shard(devices)
Creates sharded views of this Linear layer across multiple devices.
**Parameters:**
devices ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[DeviceRef](../../graph/type.md#max.graph.type.DeviceRef)]) – Iterable of `DeviceRef` devices to place the shards on.
**Returns:**
List of sharded [`Linear`](#max.nn.legacy.linear.Linear) instances, one for each device.
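As a rough picture of what a shard is, one common strategy splits the `(out_dim, in_dim)` weight along `out_dim`, one slice per device. This is a hand-rolled sketch of that idea, not the MAX sharding machinery:

```python
def shard_rowwise(weight, n_devices):
    # Split an (out_dim, in_dim) weight along out_dim into equal shards,
    # one per device (assumes out_dim is divisible by n_devices).
    out_dim = len(weight)
    assert out_dim % n_devices == 0, "out_dim must divide evenly"
    rows = out_dim // n_devices
    return [weight[i * rows:(i + 1) * rows] for i in range(n_devices)]

w = [[1, 2], [3, 4], [5, 6], [7, 8]]  # out_dim=4, in_dim=2
shards = shard_rowwise(w, 2)
print(shards[0])  # [[1, 2], [3, 4]]
```

Each device then computes a partial output against its slice, and the results are concatenated (or reduced, depending on the strategy).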
### `sharding_strategy` {#max.nn.legacy.linear.Linear.sharding_strategy}
> property sharding\_strategy: ShardingStrategy | [None](https://docs.python.org/3/library/constants.html#None)
Get the weight sharding strategy.
### `weight` {#max.nn.legacy.linear.Linear.weight}
> weight: [Weight](../../graph/Weight.md#max.graph.Weight)
The weight matrix stored on CPU with shape (out\_dim, in\_dim).
Model init transposes the weight and moves it to the target device.
### `weight_scale` {#max.nn.legacy.linear.Linear.weight_scale}
> weight\_scale: [Weight](../../graph/Weight.md#max.graph.Weight) | [None](https://docs.python.org/3/library/constants.html#None) = None
The optional weight scale stored on CPU with shape () or (N,).
Model init moves the weight\_scale to the target device if present.
## `MLP` {#max.nn.legacy.linear.MLP}
> class max.nn.legacy.linear.MLP(dtype, quantization\_encoding, hidden\_dim, feed\_forward\_length, devices, linear\_cls=Linear, has\_bias=False, activation\_function='silu', float8\_config=None, dist\_gemm\_config=None, is\_sharding=False)
Simple multi-layer perceptron composed of three [`Linear`](#max.nn.legacy.linear.Linear) layers.
Defaults to SiLU activation function.
### `shard()` {#max.nn.legacy.linear.MLP.shard}
> shard(devices)
Creates sharded views of this MLP across multiple devices.
**Parameters:**
devices ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[DeviceRef](../../graph/type.md#max.graph.type.DeviceRef)]) – Iterable of devices to place the shards on.
**Returns:**
List of sharded MLP instances, one for each device.
### `down_proj` {#max.nn.legacy.moe.MoE.down_proj}
> property down\_proj: [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
### `ep_batch_manager` {#max.nn.legacy.moe.MoE.ep_batch_manager}
> property ep\_batch\_manager: EPBatchManager
Get the expert parallel batch manager.
### `experts` {#max.nn.legacy.moe.MoE.experts}
> experts: [LayerList](layer.md#max.nn.legacy.layer.LayerList)
The list of experts.
### `gate_up_proj` {#max.nn.legacy.moe.MoE.gate_up_proj}
> property gate\_up\_proj: [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
### `shard()` {#max.nn.legacy.moe.MoE.shard}
> shard(devices)
Create sharded views of this MoE module across multiple devices.
**Parameters:**
devices ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[DeviceRef](../../graph/type.md#max.graph.type.DeviceRef)]) – Iterable of devices to place the shards on.
**Returns:**
List of sharded MoE instances, one for each device.
### `shard_devices` {#max.nn.legacy.moe.MoE.shard_devices}
> shard\_devices: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[DeviceRef](../../graph/ops.md#max.graph.ops.DeviceRef)] = []
The list of devices the MoE layer was sharded to.
### `shard_index` {#max.nn.legacy.moe.MoE.shard_index}
> shard\_index: [int](https://docs.python.org/3/library/functions.html#int) = 0
The index of the current shard (if the MoE layer was sharded).
### `sharding_strategy` {#max.nn.legacy.moe.MoE.sharding_strategy}
> property sharding\_strategy: ShardingStrategy | [None](https://docs.python.org/3/library/constants.html#None)
Get the sharding strategy for the module.
## `MoEGate` {#max.nn.legacy.moe.MoEGate}
> class max.nn.legacy.moe.MoEGate(devices, hidden\_dim, num\_experts, num\_experts\_per\_token, dtype, is\_sharding=False, linear\_cls=Linear)
Gate module for MoE.
### `shard()` {#max.nn.legacy.moe.MoEGate.shard}
> shard(devices)
Create sharded views of this MoEGate module across multiple devices.
**Parameters:**
devices ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[DeviceRef](../../graph/type.md#max.graph.type.DeviceRef)]) – Iterable of devices to place the shards on.
**Returns:**
List of sharded MoEGate instances, one for each device.
### `fused_silu_quantize()` {#max.nn.legacy.moe.Nvfp4Strategy.fused_silu_quantize}
> fused\_silu\_quantize(gate\_up\_projs, input\_scales=None, expert\_inputs=())
Applies SiLU gate then NVFP4 quantizes the result.
## `QuantStrategy` {#max.nn.legacy.moe.QuantStrategy}
> class max.nn.legacy.moe.QuantStrategy(\*args, \*\*kwargs)
Quantization strategy for MoE layers.
### `fused_silu_quantize()` {#max.nn.legacy.moe.QuantStrategy.fused_silu_quantize}
> fused\_silu\_quantize(gate\_up\_projs, input\_scales=None, expert\_inputs=())
Applies gating and quantizes activations for the down proj.
### `beta` {#max.nn.legacy.norm.ConstantLayerNorm.beta}
> beta: npt.NDArray\[np.floating\[Any]]
### `device` {#max.nn.legacy.norm.ConstantLayerNorm.device}
> device: [DeviceRef](../../graph/ops.md#max.graph.ops.DeviceRef)
### `dtype` {#max.nn.legacy.norm.ConstantLayerNorm.dtype}
> dtype: [DType](../../dtype.md#max.dtype.DType)
### `eps` {#max.nn.legacy.norm.ConstantLayerNorm.eps}
> eps: [float](https://docs.python.org/3/library/functions.html#float) = 1e-05
### `gamma` {#max.nn.legacy.norm.ConstantLayerNorm.gamma}
> gamma: npt.NDArray\[np.floating\[Any]]
## `GroupNorm` {#max.nn.legacy.norm.GroupNorm}
> class max.nn.legacy.norm.GroupNorm(num\_groups, num\_channels, eps=1e-05, affine=True, device=gpu:0)
Group normalization block.
Divides channels into groups and computes normalization stats per group.
Follows the implementation pattern from PyTorch’s group\_norm.
**Parameters:**
* num\_groups ([int](https://docs.python.org/3/library/functions.html#int)) – Number of groups to separate the channels into
* num\_channels ([int](https://docs.python.org/3/library/functions.html#int)) – Number of input channels
* eps ([float](https://docs.python.org/3/library/functions.html#float)) – Small constant added to denominator for numerical stability
* affine ([bool](https://docs.python.org/3/library/functions.html#bool)) – If True, apply learnable affine transform parameters
* device ([DeviceRef](../../graph/ops.md#max.graph.ops.DeviceRef))
### `shard()` {#max.nn.legacy.norm.LayerNorm.shard}
> shard(devices)
Creates sharded views of this LayerNorm across multiple devices.
**Parameters:**
devices ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[DeviceRef](../../graph/type.md#max.graph.type.DeviceRef)]) – Iterable of devices to place the shards on.
**Returns:**
List of sharded LayerNorm instances, one for each device.
### `sharding_strategy` {#max.nn.legacy.norm.LayerNorm.sharding_strategy}
> property sharding\_strategy: ShardingStrategy | [None](https://docs.python.org/3/library/constants.html#None)
Get the LayerNorm sharding strategy.
## `RMSNorm` {#max.nn.legacy.norm.RMSNorm}
> class max.nn.legacy.norm.RMSNorm(dim, dtype, eps=1e-06, weight\_offset=0.0, multiply\_before\_cast=True)
Computes the Root Mean Square normalization on inputs.
**Parameters:**
* dim ([int](https://docs.python.org/3/library/functions.html#int)) – Size of last dimension of the expected input.
* eps ([float](https://docs.python.org/3/library/functions.html#float)) – Value added to denominator for numerical stability.
* weight\_offset ([float](https://docs.python.org/3/library/functions.html#float)) – Constant offset added to the learned weights at runtime.
For Gemma-style RMSNorm, this should be set to 1.0.
* multiply\_before\_cast ([bool](https://docs.python.org/3/library/functions.html#bool)) – True if we multiply the inputs by the learned
weights before casting to the input type (Gemma3-style). False if we
cast the inputs to the input type first, then multiply by the learned
weights (Llama-style).
* dtype ([DType](../../dtype.md#max.dtype.DType))
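The computation can be written out as `y_i = x_i / sqrt(mean(x^2) + eps) * (w_i + weight_offset)`; a plain-Python reference sketch (illustrative, ignoring the dtype-cast ordering controlled by `multiply_before_cast`):

```python
import math

def rms_norm(x, weight, eps=1e-6, weight_offset=0.0):
    # y_i = x_i / sqrt(mean(x^2) + eps) * (w_i + weight_offset)
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * (w + weight_offset) for v, w in zip(x, weight)]

out = rms_norm([3.0, 4.0], [1.0, 1.0])
print(out)  # ≈ [0.849, 1.131], since rms([3, 4]) = sqrt(12.5) ≈ 3.536
```

For Gemma-style checkpoints, the stored weights are offsets from 1.0, hence `weight_offset=1.0`.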
### `shard()` {#max.nn.legacy.norm.RMSNorm.shard}
> shard(devices)
Creates sharded views of this RMSNorm across multiple devices.
**Parameters:**
devices ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[DeviceRef](../../graph/type.md#max.graph.type.DeviceRef)]) – Iterable of devices to place the shards on.
**Returns:**
List of sharded RMSNorm instances, one for each device.
### `beta_fast` {#max.nn.legacy.rotary_embedding.DeepseekYarnRopeScalingParams.beta_fast}
> beta\_fast: [int](https://docs.python.org/3/library/functions.html#int)
Fast interpolation rate.
### `beta_slow` {#max.nn.legacy.rotary_embedding.DeepseekYarnRopeScalingParams.beta_slow}
> beta\_slow: [int](https://docs.python.org/3/library/functions.html#int)
Slow interpolation rate.
### `mscale` {#max.nn.legacy.rotary_embedding.DeepseekYarnRopeScalingParams.mscale}
> mscale: [float](https://docs.python.org/3/library/functions.html#float)
Scaling factor for middle frequencies.
### `mscale_all_dim` {#max.nn.legacy.rotary_embedding.DeepseekYarnRopeScalingParams.mscale_all_dim}
> mscale\_all\_dim: [float](https://docs.python.org/3/library/functions.html#float)
Scaling factor applied to all dimensions.
### `original_max_position_embeddings` {#max.nn.legacy.rotary_embedding.DeepseekYarnRopeScalingParams.original_max_position_embeddings}
> original\_max\_position\_embeddings: [int](https://docs.python.org/3/library/functions.html#int)
Original maximum sequence length during training.
### `scaling_factor` {#max.nn.legacy.rotary_embedding.DeepseekYarnRopeScalingParams.scaling_factor}
> scaling\_factor: [float](https://docs.python.org/3/library/functions.html#float)
Scaling factor for frequency interpolation.
## `DeepseekYarnRotaryEmbedding` {#max.nn.legacy.rotary_embedding.DeepseekYarnRotaryEmbedding}
> class max.nn.legacy.rotary\_embedding.DeepseekYarnRotaryEmbedding(dim, n\_heads, theta, max\_seq\_len, head\_dim=None, \_freqs\_cis=None, interleaved=True, scaling\_params=None)
Deepseek’s YaRN (Yet another RoPE extensioN) Rotary Position Embedding layer.
Unlike Llama3RotaryEmbedding, the dim argument here is the rope dimension
of the model, not the hidden dimension.
### `freqs_cis_base()` {#max.nn.legacy.rotary_embedding.DeepseekYarnRotaryEmbedding.freqs_cis_base}
> freqs\_cis\_base()
Computes the frequency tensor for complex exponentials (cis)
for a given seq\_len. The tensor is scaled with the theta parameter.
Required to apply Rotary Position Embedding (RoPE) to a tensor.
See ‘RoFormer: Enhanced Transformer with Rotary Position Embedding’
(arxiv.org/pdf/2104.09864).
**Returns:**
The frequency tensor for complex exponentials with shape
(max\_seq\_len, rope\_dim // 2, 2)
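The unscaled table can be sketched in plain Python (illustrative; the real layer builds this as a graph tensor, and interleaving/layout details are omitted):

```python
import math

def freqs_cis_base(dim: int, max_seq_len: int, theta: float = 10000.0):
    # Per-pair inverse frequencies: theta^(-2i/dim) for i in [0, dim/2).
    inv_freq = [theta ** (-(2 * i) / dim) for i in range(dim // 2)]
    # Table of (cos, sin) pairs: shape (max_seq_len, dim // 2, 2).
    return [
        [[math.cos(pos * f), math.sin(pos * f)] for f in inv_freq]
        for pos in range(max_seq_len)
    ]

table = freqs_cis_base(dim=8, max_seq_len=4)
print(len(table), len(table[0]), len(table[0][0]))  # 4 4 2
print(table[0][0])  # position 0: [1.0, 0.0]
```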
### `scaling_params` {#max.nn.legacy.rotary_embedding.DeepseekYarnRotaryEmbedding.scaling_params}
> scaling\_params: [DeepseekYarnRopeScalingParams](#max.nn.legacy.rotary_embedding.DeepseekYarnRopeScalingParams) | [None](https://docs.python.org/3/library/constants.html#None) = None
## `DynamicRotaryEmbedding` {#max.nn.legacy.rotary_embedding.DynamicRotaryEmbedding}
> class max.nn.legacy.rotary\_embedding.DynamicRotaryEmbedding(dim, n\_heads, theta, max\_seq\_len, head\_dim=None, \_freqs\_cis=None, interleaved=True)
RotaryEmbedding with dynamic scaling support for long-context inference.
Dynamically updates the inv\_freq and corresponding freqs\_cis buffer if the
current sequence length exceeds the original max, or resets to the original
high-precision version for short sequences.
### `freqs_cis_base()` {#max.nn.legacy.rotary_embedding.DynamicRotaryEmbedding.freqs_cis_base}
> freqs\_cis\_base()
Computes freqs\_cis dynamically using the current self.inv\_freq.
### `maybe_update_freqs()` {#max.nn.legacy.rotary_embedding.DynamicRotaryEmbedding.maybe_update_freqs}
> maybe\_update\_freqs(position\_ids)
Update freqs\_cis if the sequence exceeds max\_seq\_len\_cached, or revert
to the original version if back below the threshold.
### `factor` {#max.nn.legacy.rotary_embedding.LinearScalingParams.factor}
> factor: [float](https://docs.python.org/3/library/functions.html#float)
Main scaling factor for the frequency components of the rope.
## `Llama3RopeScalingParams` {#max.nn.legacy.rotary_embedding.Llama3RopeScalingParams}
> class max.nn.legacy.rotary\_embedding.Llama3RopeScalingParams(factor: [float](https://docs.python.org/3/library/functions.html#float), low\_freq\_factor: [float](https://docs.python.org/3/library/functions.html#float), high\_freq\_factor: [float](https://docs.python.org/3/library/functions.html#float), orig\_max\_position: [int](https://docs.python.org/3/library/functions.html#int))
### `factor` {#max.nn.legacy.rotary_embedding.Llama3RopeScalingParams.factor}
> factor: [float](https://docs.python.org/3/library/functions.html#float)
Main scaling factor for the frequency components of the rope.
### `high_freq_factor` {#max.nn.legacy.rotary_embedding.Llama3RopeScalingParams.high_freq_factor}
> high\_freq\_factor: [float](https://docs.python.org/3/library/functions.html#float)
Factor to scale the high frequency components of the rope.
### `low_freq_factor` {#max.nn.legacy.rotary_embedding.Llama3RopeScalingParams.low_freq_factor}
> low\_freq\_factor: [float](https://docs.python.org/3/library/functions.html#float)
Factor to scale the low frequency components of the rope.
### `orig_max_position` {#max.nn.legacy.rotary_embedding.Llama3RopeScalingParams.orig_max_position}
> orig\_max\_position: [int](https://docs.python.org/3/library/functions.html#int)
The original maximum position length supported by the model.
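As a sketch of how these four parameters combine, the widely used Llama 3 scaling rule rescales each inverse frequency by its wavelength class. This mirrors the common open-source formulation and is not necessarily MAX's exact code:

```python
import math

def llama3_scale_freq(freq, factor, low_freq_factor, high_freq_factor, orig_max_position):
    # Classify each frequency by its wavelength relative to the original
    # training context window.
    wavelen = 2 * math.pi / freq
    low_freq_wavelen = orig_max_position / low_freq_factor
    high_freq_wavelen = orig_max_position / high_freq_factor
    if wavelen < high_freq_wavelen:
        return freq           # high-frequency components: unchanged
    if wavelen > low_freq_wavelen:
        return freq / factor  # low-frequency components: fully scaled
    # In between: smooth interpolation between the two regimes.
    smooth = (orig_max_position / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    return (1 - smooth) * freq / factor + smooth * freq
```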
## `Llama3RotaryEmbedding` {#max.nn.legacy.rotary_embedding.Llama3RotaryEmbedding}
> class max.nn.legacy.rotary\_embedding.Llama3RotaryEmbedding(dim, n\_heads, theta, max\_seq\_len, head\_dim=None, \_freqs\_cis=None, interleaved=True, scaling\_params=None)
RotaryEmbedding for Llama3 that takes rope scaling into account.
### `scaling_params` {#max.nn.legacy.rotary_embedding.Llama3RotaryEmbedding.scaling_params}
> scaling\_params: [Llama3RopeScalingParams](#max.nn.legacy.rotary_embedding.Llama3RopeScalingParams) | [None](https://docs.python.org/3/library/constants.html#None) = None
Scaling parameters to enable llama to function with a longer context length.
## `LongRoPERotaryEmbedding` {#max.nn.legacy.rotary_embedding.LongRoPERotaryEmbedding}
> class max.nn.legacy.rotary\_embedding.LongRoPERotaryEmbedding(dim, n\_heads, theta, max\_seq\_len, head\_dim=None, \_freqs\_cis=None, interleaved=True, scaling\_params=None)
Rotary position embedding with LongRoPE scaling for Phi-3.5 models.
### `freqs_cis_base()` {#max.nn.legacy.rotary_embedding.LongRoPERotaryEmbedding.freqs_cis_base}
> freqs\_cis\_base()
Computes the frequency tensor for complex exponentials (cis)
with LongRoPE scaling. Creates a “stitched” table where:
* Positions 0 to original\_max\_position use short\_factor
* Positions from original\_max\_position onwards use long\_factor
**Returns:**
The frequency tensor for complex exponentials with shape (max\_seq\_len \* 2, head\_dim / 2, 2)
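The stitching described above can be sketched as follows (an illustrative simplification, not the MAX implementation: it returns raw angles rather than the (cos, sin) table, and `stitched_freqs` is a hypothetical name):

```python
# Per-dimension inverse frequencies 1/theta^(2i/head_dim), divided by
# short_factor[i] for positions below original_max_position and by
# long_factor[i] from original_max_position onwards.
def stitched_freqs(head_dim, theta, max_seq_len, original_max_position,
                   short_factor, long_factor):
    half = head_dim // 2
    inv_freq = [theta ** (-2 * i / head_dim) for i in range(half)]
    table = []
    for pos in range(max_seq_len):
        factors = short_factor if pos < original_max_position else long_factor
        table.append([pos * f_inv / f for f_inv, f in zip(inv_freq, factors)])
    return table  # angles; the real table stores (cos, sin) pairs
```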
## `LongRoPEScalingParams` {#max.nn.legacy.rotary_embedding.LongRoPEScalingParams}
> class max.nn.legacy.rotary\_embedding.LongRoPEScalingParams(short\_factor, long\_factor, original\_max\_position, max\_position\_embeddings)
Parameters for LongRoPE scaling as used in Phi-3.5 models.
### `long_factor` {#max.nn.legacy.rotary_embedding.LongRoPEScalingParams.long_factor}
> long\_factor: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[float](https://docs.python.org/3/library/functions.html#float)]
Scaling factors for long sequences (can be much larger).
### `max_position_embeddings` {#max.nn.legacy.rotary_embedding.LongRoPEScalingParams.max_position_embeddings}
> max\_position\_embeddings: [int](https://docs.python.org/3/library/functions.html#int)
Current max position embeddings after scaling.
### `original_max_position` {#max.nn.legacy.rotary_embedding.LongRoPEScalingParams.original_max_position}
> original\_max\_position: [int](https://docs.python.org/3/library/functions.html#int)
Original max position embeddings the model was trained with.
### `short_factor` {#max.nn.legacy.rotary_embedding.LongRoPEScalingParams.short_factor}
> short\_factor: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[float](https://docs.python.org/3/library/functions.html#float)]
Scaling factors for short sequences (typically close to 1.0).
## `RotaryEmbedding` {#max.nn.legacy.rotary_embedding.RotaryEmbedding}
> class max.nn.legacy.rotary\_embedding.RotaryEmbedding(dim, n\_heads, theta, max\_seq\_len, head\_dim=None, \_freqs\_cis=None, interleaved=True)
RotaryEmbedding layer to calculate and apply the frequency tensor for complex exponentials.
### `dim` {#max.nn.legacy.rotary_embedding.RotaryEmbedding.dim}
> dim: [int](https://docs.python.org/3/library/functions.html#int)
### `freqs_cis` {#max.nn.legacy.rotary_embedding.RotaryEmbedding.freqs_cis}
> property freqs\_cis: [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)
### `freqs_cis_base()` {#max.nn.legacy.rotary_embedding.RotaryEmbedding.freqs_cis_base}
> freqs\_cis\_base()
Computes the frequency tensor for complex exponentials (cis)
for a given seq\_len. Tensor is scaled with theta parameter.
Required to apply Rotary Position Embedding (RoPE) to tensor.
See ‘Roformer: Enhanced Transformer with Rotary Embedding’
(arxiv.org/pdf/2104.09864).
**Returns:**
The frequency tensor for complex exponentials with shape (max\_seq\_len \* 2, head\_dim / 2, 2)
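The unscaled table from the Roformer paper can be sketched in plain Python (a minimal illustration of the formula, not the MAX implementation; the real table stores the pairs in a tensor of shape (max\_seq\_len \* 2, head\_dim / 2, 2)):

```python
import math

# angle(pos, i) = pos / theta^(2i/head_dim); each entry is the (cos, sin)
# pair for one position and one pair index.
def freqs_cis(max_seq_len, head_dim, theta=10000.0):
    half = head_dim // 2
    inv_freq = [theta ** (-2 * i / head_dim) for i in range(half)]
    return [[(math.cos(p * f), math.sin(p * f)) for f in inv_freq]
            for p in range(max_seq_len)]
```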
### `head_dim` {#max.nn.legacy.rotary_embedding.RotaryEmbedding.head_dim}
> head\_dim: [int](https://docs.python.org/3/library/functions.html#int)
head\_dim = dim // n\_heads if not specified in the config.
### `interleaved` {#max.nn.legacy.rotary_embedding.RotaryEmbedding.interleaved}
> interleaved: [bool](https://docs.python.org/3/library/functions.html#bool) = True
### `max_seq_len` {#max.nn.legacy.rotary_embedding.RotaryEmbedding.max_seq_len}
> max\_seq\_len: [int](https://docs.python.org/3/library/functions.html#int)
The maximum sequence length of the model’s input.
### `n_heads` {#max.nn.legacy.rotary_embedding.RotaryEmbedding.n_heads}
> n\_heads: [int](https://docs.python.org/3/library/functions.html#int)
### `theta` {#max.nn.legacy.rotary_embedding.RotaryEmbedding.theta}
> theta: [float](https://docs.python.org/3/library/functions.html#float)
Hyperparameter used to control the frequency scaling of the sinusoidal components of the embeddings.
## `YarnRotaryEmbedding` {#max.nn.legacy.rotary_embedding.YarnRotaryEmbedding}
> class max.nn.legacy.rotary\_embedding.YarnRotaryEmbedding(dim, n\_heads, theta, max\_seq\_len, head\_dim=None, \_freqs\_cis=None, interleaved=True, scaling\_params=None)
Generic YaRN (Yet another RoPE eNhancement) Rotary Position Embedding layer.
This implementation provides YARN scaling for models that require it,
with configurable parameters for beta\_fast, beta\_slow, and scaling factor.
### `freqs_cis_base()` {#max.nn.legacy.rotary_embedding.YarnRotaryEmbedding.freqs_cis_base}
> freqs\_cis\_base()
Computes the frequency tensor for complex exponentials (cis)
with YARN scaling applied.
## `YarnScalingParams` {#max.nn.legacy.rotary_embedding.YarnScalingParams}
Parameters for YaRN scaling of rotary embeddings.
### `beta_fast` {#max.nn.legacy.rotary_embedding.YarnScalingParams.beta_fast}
> beta\_fast: [float](https://docs.python.org/3/library/functions.html#float)
YaRN parameter for fast frequencies.
### `beta_slow` {#max.nn.legacy.rotary_embedding.YarnScalingParams.beta_slow}
> beta\_slow: [float](https://docs.python.org/3/library/functions.html#float)
YaRN parameter for slow frequencies.
### `factor` {#max.nn.legacy.rotary_embedding.YarnScalingParams.factor}
> factor: [float](https://docs.python.org/3/library/functions.html#float)
Main scaling factor for the frequency components of RoPE.
### `original_max_position_embeddings` {#max.nn.legacy.rotary_embedding.YarnScalingParams.original_max_position_embeddings}
> original\_max\_position\_embeddings: [int](https://docs.python.org/3/library/functions.html#int)
The original maximum position length supported by the model.
### `truncate` {#max.nn.legacy.rotary_embedding.YarnScalingParams.truncate}
> truncate: [bool](https://docs.python.org/3/library/functions.html#bool)
Whether to truncate the frequencies or not.
---
## sampling
Sampling custom ops.
## `MinPSampler` {#max.nn.legacy.sampling.MinPSampler}
> class max.nn.legacy.sampling.MinPSampler(dtype, shape, temperature=1)
A min\_p sampler.
**Parameters:**
* dtype ([DType](../../dtype.md#max.dtype.DType))
* shape ([Shape](../../graph/shape.md#max.graph.shape.Shape))
* temperature ([float](https://docs.python.org/3/library/functions.html#float))
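The min\_p filtering rule can be sketched in plain Python (a hedged illustration of the sampling idea only; it omits the temperature scaling and is not the MAX custom op):

```python
import random

# Keep tokens whose probability is at least min_p times the top probability,
# then sample from the renormalized survivors.
def min_p_sample(probs, min_p, rng=random.random):
    threshold = min_p * max(probs)
    kept = [(i, p) for i, p in enumerate(probs) if p >= threshold]
    total = sum(p for _, p in kept)
    r = rng() * total
    for i, p in kept:
        r -= p
        if r <= 0:
            return i
    return kept[-1][0]
```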
---
## sequential
A general sequential layer, where each layer is executed with the outputs of the previous.
## `Sequential` {#max.nn.legacy.sequential.Sequential}
> class max.nn.legacy.sequential.Sequential(layers)
A sequential stack of layers where each layer is called with the outputs
of the previous layer.
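The chaining behavior amounts to a fold over the layer list, which can be sketched as (illustrative, not the MAX implementation):

```python
from functools import reduce

# Each layer consumes the previous layer's output; the first layer
# receives the original input.
def sequential(layers, x):
    return reduce(lambda out, layer: layer(out), layers, x)
```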
---
## distributed_transformer
## `take()` {#max.nn.legacy.transformer.distributed_transformer.take}
> max.nn.legacy.transformer.distributed\_transformer.take(it, n)
Return the next `n` items from `it` as a list.
**Parameters:**
* it ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[Value](../../../graph/Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]])
* n ([int](https://docs.python.org/3/library/functions.html#int))
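The documented behavior is equivalent to the standard-library `islice` idiom (a sketch of the behavior, not necessarily the actual implementation):

```python
from itertools import islice

def take(it, n):
    """Return the next n items from it as a list."""
    return list(islice(it, n))
```

Repeated calls on the same iterator consume it incrementally.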
---
## module
Base classes and decorators for building neural network modules in MAX.
## `Module` {#max.nn.module.Module}
> class max.nn.module.Module
The core unit of composition for modeling in MAX.
Informally, a `Module` is a container class. It can contain
other `Module` instances, tensors (the `Module`’s “local parameters”)
or other arbitrary Python data.
A `Module` also has a `forward()` method which defines how the `Module`
computes its output. In the simplest case this is a function from one tensor
to another tensor. Users call the module using `__call__()` which internally
invokes `forward()`.
Formally, modules form a tree, and subtrees of modules can be manipulated
directly. A `Module` may also be thought of as a closure, where the parameters
form the data of the closure and `forward()` is the application of the closure.
Users who do not use a Python type checker, or use lax settings for their
type checker, may inherit from `Module` without parameters. Users who use
a type checker with stricter settings (including MAX internal code) should
specify explicit types for full type checking:
```python
class Linear(Module[[Tensor], Tensor]):
    def forward(self, x: Tensor) -> Tensor:
        return x @ self.weight.T + self.bias
```
**Terminology:**
* A “child” of a `Module` is a sub-`Module` stored directly on that `Module`.
* A “descendant” of a `Module` is one of its children, or one of their
descendants.
* A “parameter” is a tensor storing data on the `Module` or one of its
descendants.
* The “qualified path” of a descendant is a period-separated string
of the names of the child module attributes which lead to that
descendant module, for instance `child.sub.last`.
* The “qualified path” of a parameter is the qualified path of the
descendant directly holding that parameter, followed by a final
path component for the attribute name of the tensor.
For instance `weight` for a local parameter, or
`child.sub.last.weight` for a descendant’s parameter.
```python
from max.tensor import Tensor
from max.nn import Module, module_dataclass

@module_dataclass
class Linear(Module):
    weight: Tensor
    bias: Tensor | int = 0

    def forward(self, x: Tensor) -> Tensor:
        return x @ self.weight.T + self.bias

linear = Linear(Tensor.zeros([5, 4]))
print(linear)
print(linear(Tensor.constant([1, 2, 3, 4])))
```
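The qualified-path convention described in the terminology above can be sketched with plain Python objects standing in for modules and tensors (all names here are hypothetical, not MAX API):

```python
# Walk an object tree, joining child attribute names with periods and ending
# with the attribute name of the "parameter" (a float stands in for Tensor).
def parameters(obj, prefix=""):
    for name, value in vars(obj).items():
        path = f"{prefix}{name}"
        if isinstance(value, float):          # stand-in for a Tensor parameter
            yield path, value
        elif hasattr(value, "__dict__"):      # stand-in for a child Module
            yield from parameters(value, f"{path}.")

class Last:
    def __init__(self): self.weight = 1.0
class Sub:
    def __init__(self): self.last = Last()
class Child:
    def __init__(self): self.sub = Sub()
class Root:
    def __init__(self): self.child = Child()
```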
### `apply_to_local_parameters()` {#max.nn.module.Module.apply_to_local_parameters}
> apply\_to\_local\_parameters(f)
Applies a transformation to each local parameter tensor on the `Module`.
The transformation is applied in-place, updating the module’s values.
It is not applied to descendants’ parameters.
For example:
```python
from max.driver import Accelerator
from max.nn import Linear
model = Linear(2, 3)
model.apply_to_local_parameters(lambda _, t: t.to(Accelerator()))
```
**Parameters:**
f ([Callable](../graph/ops.md#max.graph.ops.Callable)\[\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Tensor](../tensor.md#max.tensor.Tensor)], [Tensor](../tensor.md#max.tensor.Tensor)]) –
The transformation to apply to each local parameter.
The transformation takes two arguments, a name and a tensor:
* The name is the attribute name of the parameter on the module.
* The tensor is the current value of that parameter.
The return value of this function is the new value that will
replace the value at that name.
**Return type:**
None
### `apply_to_parameters()` {#max.nn.module.Module.apply_to_parameters}
> apply\_to\_parameters(f)
Applies a transformation to all parameters in the module hierarchy.
This method traverses the module tree and applies the transformation function
to each parameter in-place, updating both the current module’s parameters
and all nested sub-module parameters. The transformation receives the
parameter’s qualified name (dot-separated path) and current tensor value.
Transfer all parameters to accelerator:
```python
from max.driver import Accelerator
from max.tensor import Tensor
from max.nn import Module, module_dataclass, Linear

@module_dataclass
class MLP(Module):
    fc1: Linear
    fc2: Linear

    def forward(self, x: Tensor) -> Tensor:
        return self.fc2(self.fc1(x))

model = MLP(
    fc1=Linear(10, 20),
    fc2=Linear(20, 5),
)
model.apply_to_parameters(lambda name, t: t.to(Accelerator()))
```
**Parameters:**
f ([Callable](../graph/ops.md#max.graph.ops.Callable)\[\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Tensor](../tensor.md#max.tensor.Tensor)], [Tensor](../tensor.md#max.tensor.Tensor)]) –
Transformation function taking `(name, tensor)` and returning
the transformed tensor. Parameters:
* `name` ([`str`](https://docs.python.org/3/library/stdtypes.html#str)): Qualified dot-separated path of the parameter
(e.g., `"fc1.weight"`, `"encoder.layer2.bias"`)
* `tensor` (`Tensor`): Current value of the parameter
Returns the new tensor value to replace the parameter.
**Return type:**
None
### `children` {#max.nn.module.Module.children}
> property children: [Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Module](#max.nn.module.Module)\[..., [Any](https://docs.python.org/3/library/typing.html#typing.Any)]]]
Iterates over the direct child modules of the `Module`.
**Yields:**
`(name, module)` pairs, where `name` is the attribute name of
the child on the module.
### `compile()` {#max.nn.module.Module.compile}
> compile(\*input\_types, weights=None)
Compiles the module to an optimized executable through graph tracing.
This method performs symbolic tracing of the module’s `forward` method
to construct a MAX `Graph`, which is then compiled and optimized for
efficient execution on CPU, GPU, or other accelerators.
The compilation process:
1. Creates symbolic `Tensor` instances based on provided type specifications
2. Executes `forward` with symbolic tensors to record operations
3. Constructs a `Graph` representing the computation
4. Includes all module parameters as weights in the graph
5. Compiles and optimizes the graph for target hardware
6. Returns an executable function with the same signature as `forward`
The input type specifications must match the signature of `forward`.
Use positional arguments for positional parameters.
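Steps 1 through 3 above (record operations by running `forward` on symbolic values, then replay them) can be sketched in miniature with scalar ops (every name here is invented for illustration; this is not how MAX traces graphs):

```python
# A symbolic value that records each op applied to it into a flat graph.
class Sym:
    def __init__(self, graph, name):
        self.graph, self.name = graph, name

    def _op(self, kind, other):
        out = Sym(self.graph, f"v{len(self.graph)}")
        rhs = other.name if isinstance(other, Sym) else other
        self.graph.append((out.name, kind, self.name, rhs))
        return out

    def __mul__(self, other): return self._op("mul", other)
    def __add__(self, other): return self._op("add", other)

def compile_fn(forward):
    graph = []
    result = forward(Sym(graph, "x"))   # run forward once on a symbol
    def run(x):                         # replay the recorded graph
        env = {"x": x}
        for out, kind, a, b in graph:
            rhs = env[b] if isinstance(b, str) else b
            env[out] = env[a] * rhs if kind == "mul" else env[a] + rhs
        return env[result.name]
    return run
```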
Basic compilation with fixed shapes:
```python
from max.dtype import DType
from max.tensor import Tensor, TensorType, defaults
from max.nn import Module, module_dataclass

@module_dataclass
class Linear(Module):
    weight: Tensor
    bias: Tensor

    def forward(self, x: Tensor) -> Tensor:
        return x @ self.weight.T + self.bias

linear = Linear(
    weight=Tensor.zeros([10, 5]),
    bias=Tensor.zeros([10]),
)

# Compile with fixed input shape
_, device = defaults()
input_type = TensorType(DType.float32, [3, 5], device=device)
model = linear.compile(input_type)

# Execute compiled model
input_data = Tensor.ones([3, 5], dtype=DType.float32)
result = model(input_data)
print(result)
```
**Parameters:**
* \*input\_types ([Type](../graph/type.md#max.graph.type.Type)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]) – Type specifications for each positional argument to
`forward`. Must match the number and order of arguments.
Each should be a `max.graph.Type` (typically
`TensorType`) describing the shape and dtype.
* weights ([Mapping](https://docs.python.org/3/library/collections.abc.html#collections.abc.Mapping)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [DLPackArray](../driver.md#max.driver.DLPackArray)] | None) – Mapping of parameter names to weight data. Weights should
be on CPU and will be transferred to the target device as part
of model initialization. If not passed, the model’s parameters
will be used as the weights.
**Returns:**
Callable\[…, Any]
A compiled executable function with the same signature as
`forward`. This function runs the optimized graph and
returns results with the same structure as `forward`
(single `Tensor` or tuple of tensors).
**Raises:**
* [TypeError](https://docs.python.org/3/library/exceptions.html#TypeError) – If input types don’t match `forward` signature or if
operations in `forward` cannot be traced.
* [RuntimeError](https://docs.python.org/3/library/exceptions.html#RuntimeError) – If graph construction fails due to incompatible
operations or parameter access issues.
### `descendants` {#max.nn.module.Module.descendants}
> property descendants: [Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Module](#max.nn.module.Module)\[..., [Any](https://docs.python.org/3/library/typing.html#typing.Any)]]]
Iterates over the `Module`’s descendant modules.
**Yields:**
`(name, module)` pairs, where `name` is the qualified path
of the descendant with respect to the module.
### `forward()` {#max.nn.module.Module.forward}
> forward(\*args, \*\*kwargs)
Defines the computation performed by the module.
Users must override this method in their subclass to define the
module’s computation.
**Parameters:**
* \*args (\~\_P) – Positional arguments for the computation.
* \*\*kwargs (\~\_P) – Keyword arguments for the computation.
**Returns:**
The result of applying the module to the input.
**Raises:**
[NotImplementedError](https://docs.python.org/3/library/exceptions.html#NotImplementedError) – If the subclass does not override this method.
**Return type:**
\_R
### `load_state()` {#max.nn.module.Module.load_state}
> load\_state(lookup)
Replaces each parameter in the module and its descendants.
The transformation is applied in-place, updating the module’s values
and those of its descendants.
For example, if we have a model with two parameters, `weight` and
`bias`, we can load the state of the model from a dictionary with the
following code:
```python
from max.tensor import Tensor
from max.nn import Linear

model = Linear(2, 3)
weights = {
    "weight": Tensor.zeros([3, 2]),
    "bias": Tensor.zeros([3]),
}
model.load_state(lambda name, _: weights[name])
```
The lookup is defined as a function rather than a dictionary, allowing
for functional remapping of names during this process to account
for differences in common weight naming and storage conventions.
For instance, certain representations may not store weights as
transposed, or may need to be quantized, or split out from a shared
qkv block, or may just have slightly different names or paths.
This can also be used, for instance, to provide default values when
initializing LoRA weights.
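Such a remapping lookup can be sketched as follows (the checkpoint keys and `rename` table are hypothetical, chosen only to illustrate renaming plus a transpose fix-up):

```python
# A checkpoint that uses a different naming convention and stores the
# weight transposed relative to the model's layout.
checkpoint = {"layers.0.w": [[1, 2, 3], [4, 5, 6]]}
rename = {"fc1.weight": "layers.0.w"}

def transpose(m):
    return [list(row) for row in zip(*m)]

def lookup(name, current):
    # Map the model's qualified name to the checkpoint's key, then undo
    # the storage-side transpose before returning the new value.
    return transpose(checkpoint[rename[name]])

# model.load_state(lookup) would then receive the remapped weight.
```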
**Parameters:**
lookup ([Callable](../graph/ops.md#max.graph.ops.Callable)\[\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Tensor](../tensor.md#max.tensor.Tensor)], [DLPackArray](../driver.md#max.driver.DLPackArray)]) –
The lookup function for each parameter:
* The first argument is the qualified name of the parameter
with respect to the module on which `load_state()` was
called.
* The second argument is the existing tensor value.
* The return value of this function is the new value that will
replace the value at that name in the module tree.
### `load_state_dict()` {#max.nn.module.Module.load_state_dict}
> load\_state\_dict(state, strict=True)
Loads parameter values from a dictionary into the module hierarchy.
This method updates all module parameters in-place by loading values from
the provided state dictionary. The dictionary maps qualified parameter names
(dot-separated paths like `"fc1.weight"`) to tensor values.
The `strict` mode (default) ensures all weights in the dictionary are
actually used, catching errors from mismatched architectures or incorrect
weight names.
For example, the following loads weights from a dictionary into a model:
```python
from max.tensor import Tensor
from max.nn import Module, module_dataclass

@module_dataclass
class Linear(Module):
    weight: Tensor
    bias: Tensor

    def forward(self, x: Tensor) -> Tensor:
        return x @ self.weight.T + self.bias

model = Linear(
    weight=Tensor.zeros([10, 5]),
    bias=Tensor.zeros([10]),
)

# Load weights from dictionary
weights = {
    "weight": Tensor.zeros([10, 5]),
    "bias": Tensor.zeros([10]),
}
model.load_state_dict(weights)
```
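The bookkeeping behind `strict` and the error cases above can be sketched with a plain dict standing in for the module tree (illustrative only):

```python
# Every parameter must be present in `state`; in strict mode, every key in
# `state` must also match a parameter.
def load_state_dict(params, state, strict=True):
    for name in params:
        if name not in state:
            raise KeyError(f"missing parameter: {name}")
        params[name] = state[name]
    unused = set(state) - set(params)
    if strict and unused:
        raise ValueError(f"unused weights: {sorted(unused)}")
```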
**Parameters:**
* state ([Mapping](https://docs.python.org/3/library/collections.abc.html#collections.abc.Mapping)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [DLPackArray](../driver.md#max.driver.DLPackArray)]) – Dictionary mapping qualified parameter names to tensor values.
Keys should match the names from [`Module.parameters`](#max.nn.module.Module.parameters) property.
Values should be DLPack-compatible arrays or `Tensor` objects.
Their shapes and dtypes must match the existing parameters with the
corresponding name, but they may be on a different device. In the
case that the new value has a different device, it will be copied to
the same device as the existing value, and the parameter will be set
to the new copy.
* strict ([bool](https://docs.python.org/3/library/functions.html#bool)) – If [`True`](https://docs.python.org/3/library/constants.html#True) (default), verify that all keys in `state`
are used (i.e., match actual parameters). If [`False`](https://docs.python.org/3/library/constants.html#False), silently
ignore extra keys that don’t match any parameters.
**Raises:**
* [ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If `strict=True` and some weights in `state` don’t
match any model parameters (indicates architecture mismatch or
incorrect weight names).
* [ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If a loaded tensor has a different dtype or shape than
the existing parameter.
* [KeyError](https://docs.python.org/3/library/exceptions.html#KeyError) – If a required parameter name in the model is missing from
`state` (regardless of `strict` setting).
**Return type:**
None
### `local_parameters` {#max.nn.module.Module.local_parameters}
> property local\_parameters: [Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Tensor](../tensor.md#max.tensor.Tensor)]]
Iterates over the local parameters of the `Module`.
**Yields:**
`(name, tensor)` pairs, where `name` is the attribute name of
the tensor on the module.
### `map_parameters()` {#max.nn.module.Module.map_parameters}
> map\_parameters(f)
Creates a new `Module` with its parameters transformed by the function.
The transformation is functional rather than in-place. The module is
deep-copied; its descendants are also replaced via the same transform
without affecting the original module.
For example:
```python
from max.driver import Accelerator
from max.nn import Linear
model = Linear(2, 3)
model_on_gpu = model.map_parameters(lambda _, t: t.to(Accelerator()))
```
**Parameters:**
f ([Callable](../graph/ops.md#max.graph.ops.Callable)\[\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Tensor](../tensor.md#max.tensor.Tensor)], [Tensor](../tensor.md#max.tensor.Tensor)]) –
The transformation to apply to each parameter.
The transformation takes two arguments, a name and a tensor:
* The name is the qualified name of the parameter
with respect to the module on which `map_parameters()`
was called.
* The tensor is the current value of that parameter.
The return value of this function is the new value that will
replace the value at that name in the module tree.
**Returns:**
A new module tree of the same type resulting from mapping the
transformation over all model parameters.
### `parameters` {#max.nn.module.Module.parameters}
> property parameters: [Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Tensor](../tensor.md#max.tensor.Tensor)]]
Iterates over all parameters in this module and its sub-modules.
This property performs a depth-first traversal of the module hierarchy,
yielding each parameter tensor with its qualified name. The qualified name
uses dot-notation to represent the module tree structure (e.g.,
`"encoder.layer1.weight"`).
Parameters are yielded in depth-first order: first the current module’s
direct parameters, then recursively each sub-module’s parameters.
Counting total parameters:
```python
from max.tensor import Tensor
from max.nn import Module, module_dataclass, Linear

@module_dataclass
class MLP(Module):
    fc1: Linear
    fc2: Linear

    def forward(self, x: Tensor) -> Tensor:
        return self.fc2(self.fc1(x))

model = MLP(
    fc1=Linear(10, 20),
    fc2=Linear(20, 5),
)

# Count parameters
total_params = sum(
    param.num_elements()
    for name, param in model.parameters
)
print(f"Total parameters: {total_params}")
```
**Yields:**
`(name, parameter)` tuples where `name` is the
dot-separated qualified path of the parameter and `parameter`
is the `Tensor`.
### `to()` {#max.nn.module.Module.to}
> to(device)
Updates the module’s parameters, transferring them to the specified device.
```python
from max.driver import CPU
from max.nn import Linear
model = Linear(2, 3)
model.to(CPU())
```
**Parameters:**
device ([Device](../driver.md#max.driver.Device)) – The device to which all model parameters will be transferred.
**Returns:**
A reference to the model. The transfer is applied mutably; internal
parameters are updated to be transferred to the specified device.
## `module_dataclass()` {#max.nn.module.module_dataclass}
> max.nn.module.module\_dataclass(cls=None, /, \*, repr=False, \*\*kwargs)
Converts a class into a MAX module with automatic parameter tracking.
This decorator enables a regular Python class to function as a [`Module`](#max.nn.module.Module),
providing automatic discovery and registration of parameters (Tensor fields)
and nested modules. The decorated class gains all capabilities of [`Module`](#max.nn.module.Module),
including parameter iteration, graph compilation via [`Module.compile()`](#max.nn.module.Module.compile),
and hierarchical module composition.
The decorator applies Python’s `@dataclass` decorator internally while
preserving [`Module`](#max.nn.module.Module)’s specialized `__repr__` method for better
debugging experience when printing module structures.
**Parameters:**
* cls ([type](https://docs.python.org/3/library/functions.html#type)\[[Module](#max.nn.module.Module)\[..., [Any](https://docs.python.org/3/library/typing.html#typing.Any)]] | None) – The class to decorate. Must define a `forward` method.
When [`None`](https://docs.python.org/3/library/constants.html#None), returns a decorator function (supports
using `@module_dataclass` with or without parentheses).
* repr ([bool](https://docs.python.org/3/library/functions.html#bool)) – If [`True`](https://docs.python.org/3/library/constants.html#True), use dataclass’s default `__repr__` instead of
[`Module`](#max.nn.module.Module)’s rich representation. Defaults to [`False`](https://docs.python.org/3/library/constants.html#False).
* \*\*kwargs – Additional keyword arguments forwarded to Python’s
`@dataclass` decorator (e.g., `frozen`, `eq`).
**Returns:**
The decorated class as a [`Module`](#max.nn.module.Module) subclass with automatic parameter
tracking and graph compilation capabilities. When `cls` is [`None`](https://docs.python.org/3/library/constants.html#None),
returns a decorator function.
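The "with or without parentheses" behavior is the standard Python decorator idiom, sketched here generically (this shows the dispatch on `cls`, not MAX internals):

```python
import dataclasses

def module_dataclass(cls=None, /, **kwargs):
    def wrap(c):
        # Apply @dataclass with any forwarded options (e.g., frozen, eq).
        return dataclasses.dataclass(**kwargs)(c)
    if cls is None:
        return wrap        # used as @module_dataclass(...) with arguments
    return wrap(cls)       # used as bare @module_dataclass
```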
---
## GemmaRMSNorm
## `GemmaRMSNorm` {#max.nn.norm.rms_norm.GemmaRMSNorm}
> class max.nn.norm.rms\_norm.GemmaRMSNorm(dim, eps=1e-06)
Computes the Root Mean Square normalization on inputs.
Differences from traditional RMSNorm:
* Computes x \* (1 + w) instead of x \* w.
* Casts as (x \* w).to(orig\_dtype) instead of x.to(orig\_dtype) \* w.
**Parameters:**
* dim ([int](https://docs.python.org/3/library/functions.html#int))
* eps ([float](https://docs.python.org/3/library/functions.html#float))
### `forward()` {#max.nn.norm.rms_norm.GemmaRMSNorm.forward}
> forward(x)
Applies Gemma RMS normalization to the input tensor.
**Parameters:**
* x ([Tensor](../../tensor.md#max.tensor.Tensor)) – The input tensor to normalize.
**Returns:**
The normalized tensor.
**Return type:**
[Tensor](../../tensor.md#max.tensor.Tensor)
---
## norm (Norm)
Normalization layers for neural networks.
## Modules
* [`rms_norm`](/max/api/python/nn/norm/rms_norm): Root Mean Square normalization layer.
## Classes
* [`GemmaRMSNorm`](/max/api/python/nn/norm/GemmaRMSNorm): RMS normalization optimized for Gemma models.
:::note Note
For legacy normalization layers (LayerNorm, GroupNorm), see [legacy/norm](/max/api/python/nn/legacy/norm).
:::
---
## rms_norm
Root mean square layer normalization.
## `RMSNorm` {#max.nn.norm.rms_norm.RMSNorm}
> class max.nn.norm.rms\_norm.RMSNorm(dim, eps=1e-06)
Computes the Root Mean Square normalization on inputs.
**Parameters:**
* dim ([int](https://docs.python.org/3/library/functions.html#int))
* eps ([float](https://docs.python.org/3/library/functions.html#float))
### `dim` {#max.nn.norm.rms_norm.RMSNorm.dim}
> property dim: [Dim](../../graph/dim.md#max.graph.dim.Dim)
### `eps` {#max.nn.norm.rms_norm.RMSNorm.eps}
> eps: [float](https://docs.python.org/3/library/functions.html#float)
### `forward()` {#max.nn.norm.rms_norm.RMSNorm.forward}
> forward(x)
Applies RMS normalization to the input tensor.
**Parameters:**
* x ([Tensor](../../tensor.md#max.tensor.Tensor)) – The input tensor to normalize.
**Returns:**
The normalized tensor.
**Return type:**
[Tensor](../../tensor.md#max.tensor.Tensor)
### `weight` {#max.nn.norm.rms_norm.RMSNorm.weight}
> weight: [Tensor](../../tensor.md#max.tensor.Tensor)
## `rms_norm()` {#max.nn.norm.rms_norm.rms_norm}
> max.nn.norm.rms\_norm.rms\_norm(x, weight, eps, weight\_offset=0.0, multiply\_before\_cast=False)
Applies Root Mean Square layer normalization to an input tensor.
**Parameters:**
* x ([Tensor](../../tensor.md#max.tensor.Tensor)) – The input tensor
* weight ([Tensor](../../tensor.md#max.tensor.Tensor)) – The weights for the normalization
* eps ([float](https://docs.python.org/3/library/functions.html#float)) – A value added to the denominator of the normalization for
numerical stability
* weight\_offset ([float](https://docs.python.org/3/library/functions.html#float)) – A value added to the weights before normalization.
Typically 1 for Gemma-like normalization and 0 otherwise.
* multiply\_before\_cast ([bool](https://docs.python.org/3/library/functions.html#bool)) – Whether to multiply before or after
casting to the output dtype. Typically True for Gemma-like
normalization and False otherwise.
**Returns:**
A layer-normalized tensor with the same shape and type as x.
**Return type:**
[Tensor](../../tensor.md#max.tensor.Tensor)
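The roles of `eps` and `weight_offset` can be sketched in pure Python (a simplified illustration of the formula only; the dtype-casting behavior controlled by `multiply_before_cast` is omitted, so that parameter does not appear):

```python
import math

# Normalize by the root mean square (eps stabilizes the denominator),
# then scale by (weight + weight_offset); weight_offset=1 gives the
# Gemma-style x * (1 + w) variant.
def rms_norm(x, weight, eps, weight_offset=0.0):
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * (w + weight_offset) for v, w in zip(x, weight)]
```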
---
## RotaryEmbedding
## `RotaryEmbedding` {#max.nn.rope.RotaryEmbedding}
> class max.nn.rope.RotaryEmbedding(weight: [max.tensor.Tensor](../../tensor.md#max.tensor.Tensor))
### `dim` {#max.nn.rope.RotaryEmbedding.dim}
> property dim: [int](https://docs.python.org/3/library/functions.html#int)
### `forward()` {#max.nn.rope.RotaryEmbedding.forward}
> forward(x, start\_pos=0)
Applies rotary positional embeddings (RoPE) to x.
seq\_len is inferred from the shape of x.
**Parameters:**
* x ([Tensor](../../tensor.md#max.tensor.Tensor)) – Activation tensor with shape (batch, seq\_len, n\_kv\_heads, head\_dim).
x is interpreted as a complex number valued tensor where the
head\_dim dimension is alternating pairs of (real, imaginary)
parts.
* start\_pos ([int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](../../graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]) – Starting position of the input tensor. Defaults to 0.
**Returns:**
Input activation tensor with rotary positional embeddings applied and
the same shape as x.
**Return type:**
[Tensor](../../tensor.md#max.tensor.Tensor)
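For intuition, the interleaved-pair rotation can be sketched in NumPy. This is an illustrative reference implementation under common RoPE conventions (base `theta=10000`), not the MAX kernel:

```python
import numpy as np

def rope_interleaved(x, start_pos=0, theta=10000.0):
    """Apply rotary embeddings to x of shape (batch, seq_len, n_heads, head_dim),
    where head_dim holds alternating (real, imaginary) pairs."""
    batch, seq_len, n_heads, head_dim = x.shape
    half = head_dim // 2
    # One rotation frequency per (real, imaginary) pair.
    freqs = theta ** (-np.arange(half) / half)                           # (half,)
    # Each absolute position p is rotated by angle p * freqs.
    angles = np.arange(start_pos, start_pos + seq_len)[:, None] * freqs  # (seq_len, half)
    cos = np.cos(angles)[None, :, None, :]
    sin = np.sin(angles)[None, :, None, :]
    # View adjacent pairs as complex numbers and rotate by e^{i * angle}.
    xc = x.reshape(batch, seq_len, n_heads, half, 2)
    xr, xi = xc[..., 0], xc[..., 1]
    out = np.stack([xr * cos - xi * sin, xr * sin + xi * cos], axis=-1)
    return out.reshape(x.shape)
```

Position 0 is left unchanged (zero rotation angle), and because each pair is rotated, the overall norm of the activations is preserved.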
## `TransposedRotaryEmbedding` {#max.nn.rope.rope.TransposedRotaryEmbedding}
### `forward()` {#max.nn.rope.rope.TransposedRotaryEmbedding.forward}
> forward(x, start\_pos=0)
Applies rotary positional embeddings (RoPE) to x.
The representation of x is transposed within the final dimension
compared to traditional RotaryEmbedding.
seq\_len is inferred from the shape of x.
**Parameters:**
* x ([Tensor](../../tensor.md#max.tensor.Tensor)) – Activation tensor with shape (batch, seq\_len, n\_kv\_heads, head\_dim).
x is interpreted as a complex number valued tensor where the
first half of head\_dim are the real parts and the last half
are the imaginary parts.
* start\_pos ([int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](../../graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]) – Starting position of the input tensor. Defaults to 0.
**Returns:**
Input activation tensor with rotary positional embeddings applied and
the same shape as x.
**Return type:**
[Tensor](../../tensor.md#max.tensor.Tensor)
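The transposed layout pairs element `i` with element `i + head_dim // 2` instead of its adjacent neighbor. A NumPy sketch of this variant, under the same illustrative conventions as above (not the MAX kernel):

```python
import numpy as np

def rope_split_half(x, start_pos=0, theta=10000.0):
    """RoPE where the first half of head_dim holds real parts and the
    second half holds imaginary parts (the 'transposed' layout)."""
    batch, seq_len, n_heads, head_dim = x.shape
    half = head_dim // 2
    freqs = theta ** (-np.arange(half) / half)
    angles = np.arange(start_pos, start_pos + seq_len)[:, None] * freqs
    cos = np.cos(angles)[None, :, None, :]
    sin = np.sin(angles)[None, :, None, :]
    # Element i pairs with element i + half rather than with element i + 1.
    xr, xi = x[..., :half], x[..., half:]
    return np.concatenate([xr * cos - xi * sin, xr * sin + xi * cos], axis=-1)
```

Because the rotation depends only on the absolute position index, shifting `start_pos` is equivalent to dropping the corresponding prefix of the sequence.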
---
## rope
Rotary positional embedding (RoPE) implementations.
## Classes
* [`RotaryEmbedding`](/max/api/python/nn/rope/RotaryEmbedding): Standard rotary position embedding implementation.
* [`TransposedRotaryEmbedding`](/max/api/python/nn/rope/TransposedRotaryEmbedding): RoPE with transposed tensor layout.
:::note Note
For legacy RoPE variants (DynamicRotaryEmbedding, YarnRotaryEmbedding, etc.),
see [legacy/rotary\_embedding](/max/api/python/nn/legacy/rotary_embedding).
:::
---
## sequential (Nn)
:::note Note
This module contains both `Sequential` and `ModuleList` containers.
For the legacy graph-based sequential container, see [legacy/sequential](/max/api/python/nn/legacy/sequential).
:::
A Module for a sequence of tensor transformations.
## `ModuleList` {#max.nn.sequential.ModuleList}
> class max.nn.sequential.ModuleList(iterable=(), /)
A `Module` subclass that is also a list container.
`ModuleList` instances will use the stringified integer index of their
submodules as the name of the module for the purposes of
qualified paths.
For example:
```python
from max.nn import Linear, Sequential
model = Sequential(
    Linear(5, 10),
    Linear(10, 5),
)
assert dict(model.parameters).keys() == {
    "0.weight", "0.bias", "1.weight", "1.bias"
}
```
### `children` {#max.nn.sequential.ModuleList.children}
> property children: [Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Module](module.md#max.nn.module.Module)\[..., [Any](https://docs.python.org/3/library/typing.html#typing.Any)]]]
Iterates over the direct child modules of the `Module`.
**Yields:**
`(name, module)` pairs, where `name` is the attribute name of
the child on the module.
## `Sequential` {#max.nn.sequential.Sequential}
> class max.nn.sequential.Sequential(\*modules)
A `Module` subclass which holds a sequence of unary modules.
A unary `Module` is one whose `forward()` method has the signature:
```default
def forward(self, x: Tensor) -> Tensor: ...
```
`Sequential` is itself a unary `Module`. Its `forward()` method
computes the result of applying each of its child modules
in sequence to its input.
For example, this applies a linear transformation up to a dimension
of 10, then a final linear transformation to reduce back to the input
dimension of 5:
```python
from max.tensor import Tensor
from max.nn import Linear, Sequential
model = Sequential(
    Linear(5, 10),
    Linear(10, 5),
)
result = model(Tensor.ones([5]))
assert result.shape == [5]
```
**Parameters:**
modules (Module) – The unary modules to apply, in sequence.
### `forward()` {#max.nn.sequential.Sequential.forward}
> forward(x)
Applies the contained modules in order.
For example, this code creates a sequence of linear transformations
which each increase the dimension of the input by 5.
The input tensor must have dim 5. The intermediate applications
will result in intermediate tensors of dim 10 and 15 respectively,
and the final result will have dim 20:
```python
from max.tensor import Tensor
from max.nn import Linear, Sequential
hidden_dims = [5, 10, 15, 20]
model = Sequential(*(
    Linear(in_dim, out_dim) for in_dim, out_dim in
    zip(hidden_dims, hidden_dims[1:])
))
result = model(Tensor.ones([5]))
assert result.shape == [20]
```
**Parameters:**
x ([Tensor](../tensor.md#max.tensor.Tensor)) – The input tensor.
**Returns:**
The result of iteratively applying each contained
module in sequence.
**Return type:**
[Tensor](../tensor.md#max.tensor.Tensor)
---
## architectures
## `register_all_models()` {#max.pipelines.architectures.register_all_models}
> max.pipelines.architectures.register\_all\_models()
Register all built-in model architectures with the global registry.
This function imports each supported model architecture module (Llama, Mistral,
Qwen, Gemma, DeepSeek, etc.) and registers their `SupportedArchitecture`
definitions with `PIPELINE_REGISTRY`.
This function is called automatically when `max.pipelines` is imported,
so you typically don’t need to call it manually. It uses an internal flag to
ensure architectures are only registered once, making repeated calls safe but
unnecessary.
### `model_config` {#max.pipelines.lib.config.AudioGenerationConfig.model_config}
> model\_config: ClassVar\[ConfigDict] = {}
Configuration for the model; should be a dictionary conforming to [ConfigDict](https://docs.pydantic.dev/latest/api/config/#pydantic.config.ConfigDict).
### `model_post_init()` {#max.pipelines.lib.config.AudioGenerationConfig.model_post_init}
> model\_post\_init(context, /)
This function is meant to behave like a BaseModel method to initialise private attributes.
It takes context as an argument since that’s what pydantic-core passes when calling it.
**Parameters:**
* self (BaseModel) – The BaseModel instance.
* context (Any) – The context.
**Return type:**
None
### `prepend_prompt_speech_tokens` {#max.pipelines.lib.config.AudioGenerationConfig.prepend_prompt_speech_tokens}
> prepend\_prompt\_speech\_tokens: [PrependPromptSpeechTokens](#max.pipelines.lib.config.PrependPromptSpeechTokens)
### `prepend_prompt_speech_tokens_causal` {#max.pipelines.lib.config.AudioGenerationConfig.prepend_prompt_speech_tokens_causal}
> prepend\_prompt\_speech\_tokens\_causal: [bool](https://docs.python.org/3/library/functions.html#bool)
### `prometheus_metrics_mode` {#max.pipelines.lib.config.AudioGenerationConfig.prometheus_metrics_mode}
> prometheus\_metrics\_mode: [PrometheusMetricsMode](#max.pipelines.lib.config.PrometheusMetricsMode)
## `PipelineConfig` {#max.pipelines.lib.config.PipelineConfig}
> class max.pipelines.lib.config.PipelineConfig(\*, config\_file=None, section\_name=None, max\_length=None, pipeline\_role=PipelineRole.PrefillAndDecode, max\_batch\_size=None, max\_queue\_size\_tg=None, min\_batch\_size\_tg=None, ep\_size=1, ce\_delay\_ms=0.0, enable\_prioritize\_first\_decode=False, enable\_chunked\_prefill=True, enable\_in\_flight\_batching=False, max\_num\_steps=-1, max\_batch\_input\_tokens=8192, enable\_echo=False, pool\_embeddings=True, chat\_template=None, use\_experimental\_kernels='false', use\_vendor\_blas='false', pdl\_level='0', custom\_architectures=\, zmq\_endpoint\_base=\, execute\_empty\_batches=False, max\_batch\_total\_tokens=None, device\_graph\_capture=False, force=False, kvcache\_ce\_watermark=0.95, enable\_overlap\_scheduler=False, use\_legacy\_module=True, defer\_resolve=False, model=\, draft\_model=None, sampling=\, profiling=\, lora=None, speculative=None)
Configuration for a pipeline.
WIP: Once a PipelineConfig is fully initialized, it should be as immutable
as possible (frozen=True). All underlying dataclass fields should have been
initialized to their default values, whether user-specified (via a CLI
flag, config file, or environment variable) or internally set to a
reasonable default.
### `log_pipeline_info()` {#max.pipelines.lib.config.PipelineConfig.log_pipeline_info}
> log\_pipeline\_info()
Logs comprehensive pipeline and KVCache configuration information.
Retrieves all necessary information from self and the PIPELINE\_REGISTRY.
Raises an error if the architecture is not found (which should not happen after config resolution).
**Return type:**
None
### `lora` {#max.pipelines.lib.config.PipelineConfig.lora}
> lora: [LoRAConfig](lora_config.md#max.pipelines.lib.lora_config.LoRAConfig) | [None](https://docs.python.org/3/library/constants.html#None)
### `max_batch_input_tokens` {#max.pipelines.lib.config.PipelineConfig.max_batch_input_tokens}
> max\_batch\_input\_tokens: [int](https://docs.python.org/3/library/functions.html#int)
### `max_batch_size` {#max.pipelines.lib.config.PipelineConfig.max_batch_size}
> max\_batch\_size: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None)
### `max_batch_total_tokens` {#max.pipelines.lib.config.PipelineConfig.max_batch_total_tokens}
> max\_batch\_total\_tokens: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None)
### `max_length` {#max.pipelines.lib.config.PipelineConfig.max_length}
> max\_length: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None)
### `max_num_steps` {#max.pipelines.lib.config.PipelineConfig.max_num_steps}
> max\_num\_steps: [int](https://docs.python.org/3/library/functions.html#int)
### `max_queue_size_tg` {#max.pipelines.lib.config.PipelineConfig.max_queue_size_tg}
> max\_queue\_size\_tg: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None)
### `min_batch_size_tg` {#max.pipelines.lib.config.PipelineConfig.min_batch_size_tg}
> min\_batch\_size\_tg: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None)
### `model` {#max.pipelines.lib.config.PipelineConfig.model}
> model: [MAXModelConfig](model_config.md#max.pipelines.lib.model_config.MAXModelConfig)
### `model_config` {#max.pipelines.lib.config.PipelineConfig.model_config}
> model\_config: ClassVar\[ConfigDict] = {}
Configuration for the model; should be a dictionary conforming to [ConfigDict](https://docs.pydantic.dev/latest/api/config/#pydantic.config.ConfigDict).
### `model_post_init()` {#max.pipelines.lib.config.PipelineConfig.model_post_init}
> model\_post\_init(context, /)
This function is meant to behave like a BaseModel method to initialise private attributes.
It takes context as an argument since that’s what pydantic-core passes when calling it.
**Parameters:**
* self (BaseModel) – The BaseModel instance.
* context (Any) – The context.
**Return type:**
None
### `pdl_level` {#max.pipelines.lib.config.PipelineConfig.pdl_level}
> pdl\_level: [str](https://docs.python.org/3/library/stdtypes.html#str)
### `pipeline_role` {#max.pipelines.lib.config.PipelineConfig.pipeline_role}
> pipeline\_role: PipelineRole
### `pool_embeddings` {#max.pipelines.lib.config.PipelineConfig.pool_embeddings}
> pool\_embeddings: [bool](https://docs.python.org/3/library/functions.html#bool)
### `profiling` {#max.pipelines.lib.config.PipelineConfig.profiling}
> profiling: ProfilingConfig
### `resolve()` {#max.pipelines.lib.config.PipelineConfig.resolve}
> resolve()
Validates and resolves the config.
Called after the config is initialized to ensure all config fields
are in a valid state.
**Return type:**
None
### `retrieve_chat_template()` {#max.pipelines.lib.config.PipelineConfig.retrieve_chat_template}
> retrieve\_chat\_template()
Returns the chat template string, or None if not set.
### `sampling` {#max.pipelines.lib.config.PipelineConfig.sampling}
> sampling: SamplingConfig
### `speculative` {#max.pipelines.lib.config.PipelineConfig.speculative}
> speculative: SpeculativeConfig | [None](https://docs.python.org/3/library/constants.html#None)
### `use_experimental_kernels` {#max.pipelines.lib.config.PipelineConfig.use_experimental_kernels}
> use\_experimental\_kernels: [str](https://docs.python.org/3/library/stdtypes.html#str)
### `use_legacy_module` {#max.pipelines.lib.config.PipelineConfig.use_legacy_module}
> use\_legacy\_module: [bool](https://docs.python.org/3/library/functions.html#bool)
### `use_vendor_blas` {#max.pipelines.lib.config.PipelineConfig.use_vendor_blas}
> use\_vendor\_blas: [str](https://docs.python.org/3/library/stdtypes.html#str)
### `zmq_endpoint_base` {#max.pipelines.lib.config.PipelineConfig.zmq_endpoint_base}
> zmq\_endpoint\_base: [str](https://docs.python.org/3/library/stdtypes.html#str)
## `PrependPromptSpeechTokens` {#max.pipelines.lib.config.PrependPromptSpeechTokens}
> class max.pipelines.lib.config.PrependPromptSpeechTokens(value, names=\, \*values, module=None, qualname=None, type=None, start=1, boundary=None)
### `NEVER` {#max.pipelines.lib.config.PrependPromptSpeechTokens.NEVER}
> NEVER = 'never'
Never prepend the prompt speech tokens sent to the audio decoder.
### `ONCE` {#max.pipelines.lib.config.PrependPromptSpeechTokens.ONCE}
> ONCE = 'once'
Prepend the prompt speech tokens to the first block of the audio decoder.
### `ROLLING` {#max.pipelines.lib.config.PrependPromptSpeechTokens.ROLLING}
> ROLLING = 'rolling'
Prepend the prompt speech tokens to the first block of the audio decoder,
and to later blocks to reach the requested buffer size.
## `PrometheusMetricsMode` {#max.pipelines.lib.config.PrometheusMetricsMode}
> class max.pipelines.lib.config.PrometheusMetricsMode(value, names=\, \*values, module=None, qualname=None, type=None, start=1, boundary=None)
### `INSTRUMENT_ONLY` {#max.pipelines.lib.config.PrometheusMetricsMode.INSTRUMENT_ONLY}
> INSTRUMENT\_ONLY = 'instrument\_only'
Instrument metrics through the Prometheus client library, relying on the application to handle the metrics server.
### `LAUNCH_MULTIPROC_SERVER` {#max.pipelines.lib.config.PrometheusMetricsMode.LAUNCH_MULTIPROC_SERVER}
> LAUNCH\_MULTIPROC\_SERVER = 'launch\_multiproc\_server'
Launch a Prometheus server in multiprocess mode to report metrics.
### `LAUNCH_SERVER` {#max.pipelines.lib.config.PrometheusMetricsMode.LAUNCH_SERVER}
> LAUNCH\_SERVER = 'launch\_server'
Launch a Prometheus server to handle metrics requests.
---
## core
## `PixelContext` {#max.pipelines.core.PixelContext}
> class max.pipelines.core.PixelContext(\*, tokens, request\_id=\, model\_name='', mask=None, tokens\_2=None, negative\_tokens=None, negative\_tokens\_2=None, extra\_params=\, timesteps=\, sigmas=\, latents=\, latent\_image\_ids=\, height=1024, width=1024, num\_inference\_steps=50, guidance\_scale=3.5, guidance=None, true\_cfg\_scale=1.0, num\_warmup\_steps=0, num\_images\_per\_prompt=1, status=GenerationStatus.ACTIVE)
A model-ready context for image/video generation requests.
Per the design doc, this class contains only numeric data that the model
will execute against. User-facing strings (prompt, negative\_prompt) are
consumed during tokenization and do not appear here.
All preprocessing is performed by PixelGenerationTokenizer.new\_context():
* Prompt tokenization -> tokens field
* Negative prompt tokenization -> negative\_tokens field
* Timestep schedule computation -> timesteps field
* Initial noise generation -> latents field
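An illustrative sketch of the schedule and noise precomputation. This uses a toy linear schedule and an assumed latent layout (16 channels, 8-pixel patches); the real tokenizer's scheduler, sigma mapping, and latent shapes are model-specific:

```python
import numpy as np

def make_schedule(num_inference_steps, num_train_timesteps=1000):
    """Illustrative linearly spaced denoising schedule (descending timesteps)."""
    timesteps = np.linspace(num_train_timesteps - 1, 0,
                            num_inference_steps, dtype=np.float32)
    sigmas = timesteps / num_train_timesteps  # toy sigma mapping
    return timesteps, sigmas

def make_latents(height, width, channels=16, patch=8, seed=0):
    """Illustrative initial Gaussian noise for a grid of latent patches."""
    rng = np.random.default_rng(seed)
    return rng.standard_normal(
        (1, height // patch, width // patch, channels)).astype(np.float32)

timesteps, sigmas = make_schedule(num_inference_steps=50)
latents = make_latents(height=1024, width=1024)
```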
**Parameters:**
* tokens ([TokenBuffer](../interfaces.md#max.interfaces.TokenBuffer)) – Tokenized prompt IDs (TokenBuffer).
* request\_id ([RequestID](../interfaces.md#max.interfaces.RequestID)) – A unique identifier for this generation request.
* model\_name ([str](https://docs.python.org/3/library/stdtypes.html#str)) – Name of the model being used.
* mask ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[bool](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.bool)]] | None)
* tokens\_2 ([TokenBuffer](../interfaces.md#max.interfaces.TokenBuffer) | None)
* negative\_tokens ([TokenBuffer](../interfaces.md#max.interfaces.TokenBuffer) | None) – Tokenized negative prompt IDs (TokenBuffer).
* negative\_tokens\_2 ([TokenBuffer](../interfaces.md#max.interfaces.TokenBuffer) | None)
* extra\_params ([dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]])
* timesteps ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]]) – Precomputed timestep schedule for denoising.
* sigmas ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]])
* latents ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]]) – Precomputed initial noise (latents).
* latent\_image\_ids ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]])
* height ([int](https://docs.python.org/3/library/functions.html#int)) – Height of the generated image/video in pixels.
* width ([int](https://docs.python.org/3/library/functions.html#int)) – Width of the generated image/video in pixels.
* num\_inference\_steps ([int](https://docs.python.org/3/library/functions.html#int)) – Number of denoising steps.
* guidance\_scale ([float](https://docs.python.org/3/library/functions.html#float)) – Guidance scale for classifier-free guidance.
* guidance ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]] | None)
* true\_cfg\_scale ([float](https://docs.python.org/3/library/functions.html#float))
* num\_warmup\_steps ([int](https://docs.python.org/3/library/functions.html#int))
* num\_images\_per\_prompt ([int](https://docs.python.org/3/library/functions.html#int)) – Number of images/videos to generate per prompt.
* status ([GenerationStatus](../interfaces.md#max.interfaces.GenerationStatus))
### `compute_num_available_steps()` {#max.pipelines.core.PixelContext.compute_num_available_steps}
> compute\_num\_available\_steps(max\_seq\_len)
Compute number of available steps for scheduler compatibility.
For image and video generation, this returns the number of inference steps.
### `width` {#max.pipelines.core.PixelContext.width}
> width: [int](https://docs.python.org/3/library/functions.html#int) = 1024
## `TTSContext` {#max.pipelines.core.TTSContext}
> class max.pipelines.core.TTSContext(\*, max\_length, tokens, request\_id=\, eos\_token\_ids=\, eos\_sequences=\, log\_probabilities=0, log\_probabilities\_echo=False, ignore\_eos=False, json\_schema=None, sampling\_params=\, model\_name='', \_matcher=None, status=GenerationStatus.ACTIVE, \_log\_probabilities\_data=\, \_is\_initial\_prompt=True, \_draft\_offset=0, target\_endpoint=None, audio\_prompt\_tokens=\, buffer\_speech\_tokens=None, audio\_buffer=None, prev\_samples\_beyond\_offset=0, streaming=False, \_speech\_token\_size=128, \_speech\_token\_end\_idx=0, \_speech\_tokens=\, decoded\_index=0, \_block\_counter=0, \_arrival\_time=\, audio\_generation\_status=GenerationStatus.ACTIVE)
A context for Text-to-Speech (TTS) model inference.
This class extends TextContext to handle speech token generation and management.
It maintains buffers for audio prompt tokens and generated speech tokens, along
with tracking indices for decoding progress.
**Parameters:**
* max\_length ([int](https://docs.python.org/3/library/functions.html#int))
* tokens ([TokenBuffer](../interfaces.md#max.interfaces.TokenBuffer))
* request\_id ([RequestID](../interfaces.md#max.interfaces.RequestID))
* eos\_token\_ids ([set](https://docs.python.org/3/library/stdtypes.html#set)\[[int](https://docs.python.org/3/library/functions.html#int)])
* eos\_sequences ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]])
* log\_probabilities ([int](https://docs.python.org/3/library/functions.html#int))
* log\_probabilities\_echo ([bool](https://docs.python.org/3/library/functions.html#bool))
* ignore\_eos ([bool](https://docs.python.org/3/library/functions.html#bool))
* json\_schema ([str](https://docs.python.org/3/library/stdtypes.html#str) | None)
* sampling\_params ([SamplingParams](../interfaces.md#max.interfaces.SamplingParams))
* model\_name ([str](https://docs.python.org/3/library/stdtypes.html#str))
* \_matcher ([Any](https://docs.python.org/3/library/typing.html#typing.Any) | None)
* status ([GenerationStatus](../interfaces.md#max.interfaces.GenerationStatus))
* \_log\_probabilities\_data ([dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[int](https://docs.python.org/3/library/functions.html#int), [LogProbabilities](../interfaces.md#max.interfaces.LogProbabilities)])
* \_is\_initial\_prompt ([bool](https://docs.python.org/3/library/functions.html#bool))
* \_draft\_offset ([int](https://docs.python.org/3/library/functions.html#int))
* target\_endpoint ([str](https://docs.python.org/3/library/stdtypes.html#str) | None)
* audio\_prompt\_tokens ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]]) – Array of input audio prompt tokens used for voice cloning
* buffer\_speech\_tokens ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]] | None)
* audio\_buffer ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]] | None)
* prev\_samples\_beyond\_offset ([int](https://docs.python.org/3/library/functions.html#int))
* streaming ([bool](https://docs.python.org/3/library/functions.html#bool)) – Whether the request streams the audio to the client
* \_speech\_token\_size ([int](https://docs.python.org/3/library/functions.html#int)) – Size of the speech token buffer, defaults to SPEECH\_TOKEN\_audio\_chunk\_size
* \_speech\_token\_end\_idx ([int](https://docs.python.org/3/library/functions.html#int)) – Index marking the end of valid speech tokens
* \_speech\_tokens ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]]) – Buffer containing the generated speech tokens
* decoded\_index ([int](https://docs.python.org/3/library/functions.html#int))
* \_block\_counter ([int](https://docs.python.org/3/library/functions.html#int)) – Counter tracking number of speech token blocks generated
* \_arrival\_time ([float](https://docs.python.org/3/library/functions.html#float))
* audio\_generation\_status ([GenerationStatus](../interfaces.md#max.interfaces.GenerationStatus))
### `audio_buffer` {#max.pipelines.core.TTSContext.audio_buffer}
> audio\_buffer: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]] | [None](https://docs.python.org/3/library/constants.html#None) = None
### `audio_generation_status` {#max.pipelines.core.TTSContext.audio_generation_status}
> audio\_generation\_status: [GenerationStatus](../interfaces.md#max.interfaces.GenerationStatus) = 'active'
### `audio_prompt_tokens` {#max.pipelines.core.TTSContext.audio_prompt_tokens}
> audio\_prompt\_tokens: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]]
### `block_counter` {#max.pipelines.core.TTSContext.block_counter}
> property block\_counter: [int](https://docs.python.org/3/library/functions.html#int)
### `buffer_speech_tokens` {#max.pipelines.core.TTSContext.buffer_speech_tokens}
> buffer\_speech\_tokens: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]] | [None](https://docs.python.org/3/library/constants.html#None) = None
### `decoded_index` {#max.pipelines.core.TTSContext.decoded_index}
> decoded\_index: [int](https://docs.python.org/3/library/functions.html#int) = 0
### `is_done` {#max.pipelines.core.TTSContext.is_done}
> property is\_done: [bool](https://docs.python.org/3/library/functions.html#bool)
### `next_speech_tokens()` {#max.pipelines.core.TTSContext.next_speech_tokens}
> next\_speech\_tokens(audio\_chunk\_size=None, buffer=None)
Returns a chunk of the next unseen speech tokens.
Calling this function will not update the index of the last seen
token. This must be done by setting decoded\_index after the chunk
is processed.
**Parameters:**
* audio\_chunk\_size ([int](https://docs.python.org/3/library/functions.html#int) | None) – The number of speech tokens to return.
* buffer ([int](https://docs.python.org/3/library/functions.html#int) | None) – The number of previous speech tokens to pass to the audio
decoder on each generation step.
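The chunking semantics can be sketched as follows. This is a standalone sketch over a plain array; the real method reads these values from the context and, as noted above, leaves `decoded_index` for the caller to update:

```python
import numpy as np

def next_speech_tokens_sketch(speech_tokens, end_idx, decoded_index,
                              audio_chunk_size=None, buffer=None):
    """Return (chunk, next_decoded_index) of not-yet-decoded speech tokens,
    optionally prefixed with `buffer` previously decoded tokens."""
    # Stop after at most audio_chunk_size new tokens (or all remaining ones).
    stop = end_idx if audio_chunk_size is None else min(
        end_idx, decoded_index + audio_chunk_size)
    # Optionally re-include up to `buffer` previous tokens for the decoder.
    start = decoded_index if buffer is None else max(0, decoded_index - buffer)
    return speech_tokens[start:stop], stop

tokens = np.arange(10)
chunk, next_idx = next_speech_tokens_sketch(
    tokens, end_idx=10, decoded_index=4, audio_chunk_size=3, buffer=2)
```

Here the chunk spans indices 2 through 6: two previously decoded tokens of context plus three new ones, and the caller would then set `decoded_index = 7`.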
## `TextAndVisionContext` {#max.pipelines.core.TextAndVisionContext}
> class max.pipelines.core.TextAndVisionContext(\*, max\_length, tokens, request\_id=\, eos\_token\_ids=\, eos\_sequences=\, log\_probabilities=0, log\_probabilities\_echo=False, ignore\_eos=False, json\_schema=None, sampling\_params=\, model\_name='', \_matcher=None, status=GenerationStatus.ACTIVE, \_log\_probabilities\_data=\, \_is\_initial\_prompt=True, \_draft\_offset=0, target\_endpoint=None, vision\_token\_ids, images=\, extra\_model\_args=\)
A base class for model context, specifically for Vision model variants.
For example:
```default
- vision start token = 97
- image patch token  = 98
- vision end token   = 99
```
Token array:
```default
- idx: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 ]
- token_ids: [ 51 52 53 54 97 98 98 98 98 99 55 56 57 58 97 98 98 98 98 99 59 60 61 62 ]
^-- img0 --^ ^-- img1 --^
^ start_idx=11 (image_idx=1)
```
Then we would have:
```default
- ImageMetadata(start_idx=5, end_idx=9, ...) # img0
- ImageMetadata(start_idx=15, end_idx=19, ...) # img1
```
These image ranges should be non-overlapping.
The image\_idx is determined based on the value of start\_idx. It is the idx of
the first image that is not yet encoded. For example in the above diagram
when start\_idx=11, this implies that image\_idx=1.
Currently, start\_idx and current\_position are not allowed to fall in the
middle of an image. This is verified by \_validate\_state methods that are
called before and after mutating methods such as \_bump\_token\_indices.
### `compute_image_aligned_idx()` {#max.pipelines.core.TextAndVisionContext.compute_image_aligned_idx}
> compute\_image\_aligned\_idx(idx)
Aligns an index value downward if it lies in the middle of an image.
### `extra_model_args` {#max.pipelines.core.TextAndVisionContext.extra_model_args}
> extra\_model\_args: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]]
Extra model arguments for the vision model. These are model specific arguments.
### `image_idx` {#max.pipelines.core.TextAndVisionContext.image_idx}
> property image\_idx: [int](https://docs.python.org/3/library/functions.html#int)
Index of the next unencoded image in the prompt.
### `images` {#max.pipelines.core.TextAndVisionContext.images}
> images: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[ImageMetadata](../interfaces.md#max.interfaces.ImageMetadata)]
Metadata about each image in the prompt.
### `needs_vision_encoding` {#max.pipelines.core.TextAndVisionContext.needs_vision_encoding}
> property needs\_vision\_encoding: [bool](https://docs.python.org/3/library/functions.html#bool)
Returns whether vision encoding is needed for this context.
### `next_images` {#max.pipelines.core.TextAndVisionContext.next_images}
> property next\_images: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[ImageMetadata](../interfaces.md#max.interfaces.ImageMetadata)]
Returns the images that are not yet encoded.
### `update()` {#max.pipelines.core.TextAndVisionContext.update}
> update(new\_token, log\_probabilities=None)
Updates the next\_tokens and extends existing tokens to include all generated tokens.
### `vision_token_ids` {#max.pipelines.core.TextAndVisionContext.vision_token_ids}
> vision\_token\_ids: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]
The ID of the image special token. This is a list primarily because Pixtral
also defines an image\_break\_token\_id.
## `TextContext` {#max.pipelines.core.TextContext}
> class max.pipelines.core.TextContext(\*, max\_length, tokens, request\_id=\<factory>, eos\_token\_ids=\<factory>, eos\_sequences=\<factory>, log\_probabilities=0, log\_probabilities\_echo=False, ignore\_eos=False, json\_schema=None, sampling\_params=\<factory>, model\_name='', \_matcher=None, status=GenerationStatus.ACTIVE, \_log\_probabilities\_data=\<factory>, \_is\_initial\_prompt=True, \_draft\_offset=0, target\_endpoint=None)
A base class for model context, specifically for Text model variants.
This class manages the state and processing of text generation, including token management,
caching, and generation parameters.
**Parameters:**
* max\_length ([int](https://docs.python.org/3/library/functions.html#int)) – Maximum allowed length of the generated sequence
* tokens ([TokenBuffer](../interfaces.md#max.interfaces.TokenBuffer)) – NumPy array containing the token IDs
* request\_id ([RequestID](../interfaces.md#max.interfaces.RequestID)) – A unique identifier for this sequence.
* eos\_token\_ids ([set](https://docs.python.org/3/library/stdtypes.html#set)\[[int](https://docs.python.org/3/library/functions.html#int)]) – Set of token IDs that indicate end of sequence
* eos\_sequences ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]])
* log\_probabilities ([int](https://docs.python.org/3/library/functions.html#int)) – Whether to return token log probabilities
* log\_probabilities\_echo ([bool](https://docs.python.org/3/library/functions.html#bool)) – Whether to return log probabilities for prompt tokens
* ignore\_eos ([bool](https://docs.python.org/3/library/functions.html#bool)) – Whether to ignore end of sequence tokens and continue generating
* json\_schema ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) – Optional JSON schema for structured output
* sampling\_params ([SamplingParams](../interfaces.md#max.interfaces.SamplingParams)) – Parameters controlling the token sampling strategy
* model\_name ([str](https://docs.python.org/3/library/stdtypes.html#str))
* \_matcher ([Any](https://docs.python.org/3/library/typing.html#typing.Any) | None)
* status ([GenerationStatus](../interfaces.md#max.interfaces.GenerationStatus))
* \_log\_probabilities\_data ([dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[int](https://docs.python.org/3/library/functions.html#int), [LogProbabilities](../interfaces.md#max.interfaces.LogProbabilities)]) – Token log probabilities data
* \_is\_initial\_prompt ([bool](https://docs.python.org/3/library/functions.html#bool)) – Whether this is the initial prompt encoding
* \_draft\_offset ([int](https://docs.python.org/3/library/functions.html#int)) – Offset for draft decoding
* target\_endpoint ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) – Optional target endpoint identifier for routing requests
### `compute_num_available_steps()` {#max.pipelines.core.TextContext.compute_num_available_steps}
> compute\_num\_available\_steps(max\_seq\_len)
Computes the maximum number of steps that can be executed for a given context
without exceeding max\_seq\_len.
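Conceptually, this clamps the step count to the remaining room in the sequence. A simplified sketch (the names here are illustrative, not the real signature):

```python
def num_available_steps(current_length: int, max_seq_len: int) -> int:
    # Each step produces one token, so the number of steps is bounded by
    # the remaining room before max_seq_len is reached.
    return max(0, max_seq_len - current_length)
```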
### `eos_sequences` {#max.pipelines.core.TextContext.eos_sequences}
> eos\_sequences: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]]
### `eos_token_ids` {#max.pipelines.core.TextContext.eos_token_ids}
> eos\_token\_ids: [set](https://docs.python.org/3/library/stdtypes.html#set)\[[int](https://docs.python.org/3/library/functions.html#int)]
### `get_min_token_logit_mask()` {#max.pipelines.core.TextContext.get_min_token_logit_mask}
> get\_min\_token\_logit\_mask(num\_steps)
Returns a set of indices for the tokens in the output that should be masked.
This is primarily used for the min\_tokens setting, where we mask
eos tokens in the logits to avoid generating them before we reach
min\_tokens.
**Returns:**
A set of indices for the tokens in the output that should be masked.
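A simplified sketch of the masking logic described above; the function name and the (step, token\_id) index pairs are illustrative assumptions:

```python
def min_tokens_logit_mask(
    eos_token_ids: set[int],
    num_generated: int,
    min_tokens: int,
    num_steps: int,
) -> set[tuple[int, int]]:
    """Return (step, token_id) pairs whose logits should be masked."""
    masked = set()
    for step in range(num_steps):
        # Keep suppressing EOS until min_tokens tokens have been generated.
        if num_generated + step < min_tokens:
            for token_id in eos_token_ids:
                masked.add((step, token_id))
    return masked
```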
### `ignore_eos` {#max.pipelines.core.TextContext.ignore_eos}
> ignore\_eos: [bool](https://docs.python.org/3/library/functions.html#bool) = False
### `is_done` {#max.pipelines.core.TextContext.is_done}
> property is\_done: [bool](https://docs.python.org/3/library/functions.html#bool)
### `is_initial_prompt` {#max.pipelines.core.TextContext.is_initial_prompt}
> property is\_initial\_prompt: [bool](https://docs.python.org/3/library/functions.html#bool)
Returns True if the context has not yet been updated with generated tokens.
### `json_schema` {#max.pipelines.core.TextContext.json_schema}
> json\_schema: [str](https://docs.python.org/3/library/stdtypes.html#str) | [None](https://docs.python.org/3/library/constants.html#None) = None
### `jump_ahead()` {#max.pipelines.core.TextContext.jump_ahead}
> jump\_ahead(new\_token)
Updates the token array, while ensuring the new token is returned to the user.
### `request_id` {#max.pipelines.core.TextContext.request_id}
> request\_id: [RequestID](../interfaces.md#max.interfaces.RequestID)
### `reset()` {#max.pipelines.core.TextContext.reset}
> reset()
Resets the context’s state by combining all tokens into a new prompt.
### `status` {#max.pipelines.core.TextContext.status}
> status: [GenerationStatus](../interfaces.md#max.interfaces.GenerationStatus) = 'active'
### `target_endpoint` {#max.pipelines.core.TextContext.target_endpoint}
> target\_endpoint: [str](https://docs.python.org/3/library/stdtypes.html#str) | [None](https://docs.python.org/3/library/constants.html#None) = None
### `to_generation_output()` {#max.pipelines.core.TextContext.to_generation_output}
> to\_generation\_output()
Get completion tokens that are ready to be returned to the user.
This method retrieves tokens that have been generated but not yet
delivered to the user, along with their associated log probability data.
**Returns:**
The completion tokens and their associated
log probabilities, if available.
### `tokens` {#max.pipelines.core.TextContext.tokens}
> tokens: [TokenBuffer](../interfaces.md#max.interfaces.TokenBuffer)
### `update()` {#max.pipelines.core.TextContext.update}
> update(new\_token, log\_probabilities=None)
Updates the next\_tokens and extends existing tokens to include all generated tokens.
### `update_with_future_token()` {#max.pipelines.core.TextContext.update_with_future_token}
> update\_with\_future\_token()
Append a placeholder future token to the generated tokens.
This is primarily used for overlap scheduling.
**Return type:**
None
## `reserve_token_space_for_batch()` {#max.pipelines.core.reserve_token_space_for_batch}
> max.pipelines.core.reserve\_token\_space\_for\_batch(batch, num\_tokens)
Temporarily reserves token space for each context in a batch by incrementing
the \_active\_idx and \_end\_idx attributes by num\_tokens for the duration
of the context. These indices are restored to their original values upon exit.
**Parameters:**
* batch – List of TextContext objects to reserve space for.
* num\_tokens – Number of tokens to reserve for each context.
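The reserve-and-restore behavior can be sketched as a context manager; `FakeContext` is a hypothetical stand-in for TextContext's internal indices:

```python
from contextlib import contextmanager


class FakeContext:
    """Hypothetical stand-in exposing TextContext's internal indices."""

    def __init__(self) -> None:
        self._active_idx = 4
        self._end_idx = 4


@contextmanager
def reserve_token_space_for_batch(batch, num_tokens):
    # Temporarily bump each context's indices...
    for ctx in batch:
        ctx._active_idx += num_tokens
        ctx._end_idx += num_tokens
    try:
        yield batch
    finally:
        # ...and restore the original values on exit.
        for ctx in batch:
            ctx._active_idx -= num_tokens
            ctx._end_idx -= num_tokens


ctx = FakeContext()
with reserve_token_space_for_batch([ctx], 8):
    reserved = (ctx._active_idx, ctx._end_idx)  # bumped inside the block
restored = (ctx._active_idx, ctx._end_idx)      # original values after exit
```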
## `validate_aspect_ratio_args()` {#max.pipelines.core.validate_aspect_ratio_args}
> max.pipelines.core.validate\_aspect\_ratio\_args(context)
Validates that required aspect ratio arguments are present for vision input.
**Parameters:**
context ([TextContext](#max.pipelines.core.TextContext) | [TextAndVisionContext](#max.pipelines.core.TextAndVisionContext)) – The context to validate.
**Raises:**
InputError – If required aspect ratio arguments are missing.
**Return type:**
None
## `validate_image_grid_thw_args()` {#max.pipelines.core.validate_image_grid_thw_args}
> max.pipelines.core.validate\_image\_grid\_thw\_args(context)
Validates that image\_grid\_thw is present when vision encoding is needed.
**Parameters:**
context ([TextContext](#max.pipelines.core.TextContext) | [TextAndVisionContext](#max.pipelines.core.TextAndVisionContext)) – The context to validate.
**Raises:**
InputError – If image\_grid\_thw is missing from extra\_model\_args when
vision encoding is needed.
**Return type:**
None
## `validate_image_shape_5d()` {#max.pipelines.core.validate_image_shape_5d}
> max.pipelines.core.validate\_image\_shape\_5d(context)
Validates that images have the expected 5-dimensional shape.
**Parameters:**
context ([TextContext](#max.pipelines.core.TextContext) | [TextAndVisionContext](#max.pipelines.core.TextAndVisionContext)) – The context to validate.
**Raises:**
InputError – If the image shape is not 5-dimensional.
**Return type:**
None
## `validate_initial_prompt_has_image()` {#max.pipelines.core.validate_initial_prompt_has_image}
> max.pipelines.core.validate\_initial\_prompt\_has\_image(context)
Validates that initial prompts contain an image for vision models.
**Parameters:**
context ([TextContext](#max.pipelines.core.TextContext) | [TextAndVisionContext](#max.pipelines.core.TextAndVisionContext)) – The context to validate.
**Raises:**
InputError – If the initial prompt doesn’t contain an image.
**Return type:**
None
## `validate_only_one_image()` {#max.pipelines.core.validate_only_one_image}
> max.pipelines.core.validate\_only\_one\_image(context)
Validates that at most one image is provided in the context.
**Parameters:**
context ([TextContext](#max.pipelines.core.TextContext) | [TextAndVisionContext](#max.pipelines.core.TextAndVisionContext)) – The context to validate.
**Raises:**
InputError – If more than one image is provided.
**Return type:**
None
## `validate_requires_vision_context()` {#max.pipelines.core.validate_requires_vision_context}
> max.pipelines.core.validate\_requires\_vision\_context(context)
Validates that the context is a TextAndVisionContext.
**Parameters:**
context ([TextContext](#max.pipelines.core.TextContext) | [TextAndVisionContext](#max.pipelines.core.TextAndVisionContext)) – The context to validate.
**Raises:**
InputError – If the context is not a TextAndVisionContext.
**Return type:**
None
## `validate_vision_position_ids()` {#max.pipelines.core.validate_vision_position_ids}
> max.pipelines.core.validate\_vision\_position\_ids(context)
Validates that vision\_position\_ids is present when vision encoding is needed.
**Parameters:**
context ([TextContext](#max.pipelines.core.TextContext) | [TextAndVisionContext](#max.pipelines.core.TextAndVisionContext)) – The context to validate.
**Raises:**
InputError – If vision\_position\_ids is missing from extra\_model\_args when
vision encoding is needed.
**Return type:**
None
---
## hf_utils
Utilities for interacting with Hugging Face Files/Repos.
## `HuggingFaceRepo` {#max.pipelines.lib.hf_utils.HuggingFaceRepo}
> class max.pipelines.lib.hf\_utils.HuggingFaceRepo(repo\_id, revision='main', trust\_remote\_code=False, repo\_type=None)
Handle for interacting with a Hugging Face repository (remote or local).
### `encoding_for_file()` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.encoding_for_file}
> encoding\_for\_file(file)
Infers the supported encoding for a given weight file path.
### `file_exists()` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.file_exists}
> file\_exists(filename)
Returns whether the given file exists in the repo.
### `files_for_encoding()` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.files_for_encoding}
> files\_for\_encoding(encoding, weights\_format=None)
Returns paths to weight files for the given encoding (and optionally format).
### `formats_available` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.formats_available}
> property formats\_available: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[WeightsFormat](../graph/weights.md#max.graph.weights.WeightsFormat)]
Returns the weight formats available in this repo.
### `info` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.info}
> property info: ModelInfo
Returns Hugging Face model info (online repos only).
### `repo_id` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.repo_id}
> repo\_id: [str](https://docs.python.org/3/library/stdtypes.html#str)
The Hugging Face repo ID. Despite the name, it can be either a remote Hugging
Face repo ID or a local path.
### `repo_type` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.repo_type}
> repo\_type: RepoType | [None](https://docs.python.org/3/library/constants.html#None) = None
The type of repo. This is inferred from the repo\_id.
### `revision` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.revision}
> revision: [str](https://docs.python.org/3/library/stdtypes.html#str) = 'main'
The revision to use for the repo.
### `size_of()` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.size_of}
> size\_of(filename)
Returns file size in bytes for online repos, or None.
### `supported_encodings` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.supported_encodings}
> property supported\_encodings: [list](https://docs.python.org/3/library/stdtypes.html#list)\[SupportedEncoding]
Returns encodings supported by this repo’s weight files.
### `trust_remote_code` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.trust_remote_code}
> trust\_remote\_code: [bool](https://docs.python.org/3/library/functions.html#bool) = False
Whether to trust remote code.
### `weight_files` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.weight_files}
> property weight\_files: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[WeightsFormat](../graph/weights.md#max.graph.weights.WeightsFormat), [list](https://docs.python.org/3/library/stdtypes.html#list)\[[str](https://docs.python.org/3/library/stdtypes.html#str)]]
Returns weight file paths grouped by format (safetensors, gguf).
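Grouping by format presumably keys off the file extension. A minimal sketch of that idea, not the actual implementation:

```python
from collections import defaultdict


def group_weight_files(filenames: list[str]) -> dict[str, list[str]]:
    """Group weight file paths by format, inferred from the extension."""
    groups: dict[str, list[str]] = defaultdict(list)
    for name in filenames:
        if name.endswith(".safetensors"):
            groups["safetensors"].append(name)
        elif name.endswith(".gguf"):
            groups["gguf"].append(name)
    return dict(groups)
```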
## `download_weight_files()` {#max.pipelines.lib.hf_utils.download_weight_files}
> max.pipelines.lib.hf\_utils.download\_weight\_files(huggingface\_model\_id, filenames, revision=None, force\_download=False, max\_workers=8)
Downloads weight files for a Hugging Face model and returns local paths.
**Parameters:**
* huggingface\_model\_id ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The Hugging Face model identifier, e.g. modularai/Llama-3.1-8B-Instruct-GGUF
* filenames ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[str](https://docs.python.org/3/library/stdtypes.html#str)]) – A list of file paths relative to the root of the Hugging Face repo.
If files provided are available locally, download is skipped, and
the local files are used.
* revision ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) – The Hugging Face revision to use. If provided, the local cache is
checked directly, avoiding a network call to Hugging Face.
* force\_download ([bool](https://docs.python.org/3/library/functions.html#bool)) – Whether to force the files to be re-downloaded, even if they are
already available in the local cache or at a provided path.
* max\_workers ([int](https://docs.python.org/3/library/functions.html#int)) – The number of worker threads used to download files concurrently.
## `generate_local_model_path()` {#max.pipelines.lib.hf_utils.generate_local_model_path}
> max.pipelines.lib.hf\_utils.generate\_local\_model\_path(repo\_id, revision)
Generate the local filesystem path where a Hugging Face model repo is cached.
This function uses Hugging Face’s official snapshot\_download with local\_files\_only=True
to resolve the local cache path for a model repository.
**Parameters:**
* repo\_id ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The Hugging Face repository ID in the format “org/model”
(e.g. “HuggingFaceTB/SmolLM2-135M”)
* revision ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The specific model revision hash to use, typically from a repo lock file
**Returns:**
The absolute path to the cached model files for the specified revision.
**Raises:**
[FileNotFoundError](https://docs.python.org/3/library/exceptions.html#FileNotFoundError) – If the model is not found in the local cache
## `is_diffusion_pipeline()` {#max.pipelines.lib.hf_utils.is_diffusion_pipeline}
> max.pipelines.lib.hf\_utils.is\_diffusion\_pipeline(repo)
Check if a Hugging Face repository is a diffusion pipeline.
Diffusion pipelines typically have a model\_index.json file that describes
the pipeline components.
**Parameters:**
repo ([HuggingFaceRepo](#max.pipelines.lib.hf_utils.HuggingFaceRepo)) – The HuggingFaceRepo to check.
**Returns:**
True if the repository appears to be a diffusion pipeline, False otherwise.
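Based on the description above, the check reduces to testing for model\_index.json. A sketch using a fake repo object (the real function takes a HuggingFaceRepo):

```python
class FakeRepo:
    """Hypothetical stand-in for HuggingFaceRepo's file_exists()."""

    def __init__(self, files: set[str]) -> None:
        self.files = files

    def file_exists(self, filename: str) -> bool:
        return filename in self.files


def is_diffusion_pipeline(repo) -> bool:
    # Diffusion pipelines describe their components in model_index.json.
    return repo.file_exists("model_index.json")
```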
## `try_to_load_from_cache()` {#max.pipelines.lib.hf_utils.try_to_load_from_cache}
> max.pipelines.lib.hf\_utils.try\_to\_load\_from\_cache(repo\_id, filename, revision)
Wrapper around `huggingface_hub.try_to_load_from_cache` that first calls
`validate_hf_repo_access` to ensure the repo exists.
## `validate_hf_repo_access()` {#max.pipelines.lib.hf_utils.validate_hf_repo_access}
> max.pipelines.lib.hf\_utils.validate\_hf\_repo\_access(repo\_id, revision)
Validates repository access and raises clear, user-friendly errors.
Results are cached to avoid redundant Hugging Face API calls when the same
repository is validated multiple times within a process.
**Parameters:**
* repo\_id ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The Hugging Face repository ID to validate
* revision ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The revision/branch to validate
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – With user-friendly error messages for various access issues
**Return type:**
None
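The caching behavior resembles memoization over (repo\_id, revision). A sketch using `functools.lru_cache`, with a call counter standing in for the network request:

```python
from functools import lru_cache

api_calls: list[tuple[str, str]] = []


@lru_cache(maxsize=None)
def validate_repo(repo_id: str, revision: str) -> bool:
    # Hypothetical stand-in for the real Hugging Face API check.
    api_calls.append((repo_id, revision))
    return True


validate_repo("org/model", "main")
validate_repo("org/model", "main")  # served from the cache; no second call
```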
---
## pipelines
The pipelines package provides end-to-end implementations for text
generation, embeddings, audio generation, and speech processing that
automatically convert Hugging Face models into performance-optimized MAX graphs.
Each pipeline can be served via OpenAI-compatible endpoints for production
deployment.
## Modules
* [`architectures`](/max/api/python/pipelines/architectures)
* [`config`](/max/api/python/pipelines/config)
* [`core`](/max/api/python/pipelines/core)
* [`hf_utils`](/max/api/python/pipelines/hf_utils)
* [`interfaces`](/max/api/python/pipelines/interfaces)
* [`lora_config`](/max/api/python/pipelines/lora_config)
* [`model_config`](/max/api/python/pipelines/model_config)
* [`pipeline`](/max/api/python/pipelines/pipeline)
* [`registry`](/max/api/python/pipelines/registry)
* [`sampling`](/max/api/python/pipelines/sampling)
* [`tokenizer`](/max/api/python/pipelines/tokenizer)
---
## interfaces (Pipelines)
Interfaces for MAX pipelines.
## `AlwaysSignalBuffersMixin` {#max.pipelines.lib.interfaces.AlwaysSignalBuffersMixin}
> class max.pipelines.lib.interfaces.AlwaysSignalBuffersMixin
Bases: [`object`](https://docs.python.org/3/library/functions.html#object)
Mixin for models that always require signal buffers.
Use this for models that use VocabParallelEmbedding or other distributed
components that always perform allreduce, even on single-device setups.
Models using this mixin build graphs that always include signal buffer
inputs, regardless of device count. This is typically because they use
distributed embedding layers or other components that call allreduce
operations unconditionally.
### `devices` {#max.pipelines.lib.interfaces.AlwaysSignalBuffersMixin.devices}
> devices: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[Device](../driver.md#max.driver.Device)]
Device list that must be provided by the model class.
### `signal_buffers` {#max.pipelines.lib.interfaces.AlwaysSignalBuffersMixin.signal_buffers}
> property signal\_buffers: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[Buffer](../driver.md#max.driver.Buffer)]
Override to always create signal buffers.
Models using this mixin have distributed components that always
perform allreduce, even for single-device setups. Therefore,
signal buffers are always required to match the graph inputs.
In compile-only mode (virtual device mode), returns an empty list
to avoid GPU memory allocation which is not supported.
**Returns:**
List of signal buffer tensors, one per device, or empty list
in compile-only mode.
## `ArchConfig` {#max.pipelines.lib.interfaces.ArchConfig}
> class max.pipelines.lib.interfaces.ArchConfig(\*args, \*\*kwargs)
Bases: [`Protocol`](https://docs.python.org/3/library/typing.html#typing.Protocol)
Config for a model architecture.
### `get_max_seq_len()` {#max.pipelines.lib.interfaces.ArchConfig.get_max_seq_len}
> get\_max\_seq\_len()
Returns the default maximum sequence length for the model.
Subclasses should determine whether this value can be overridden by
setting the `--max-length` (`pipeline_config.max_length`) flag.
### `initialize()` {#max.pipelines.lib.interfaces.ArchConfig.initialize}
> classmethod initialize(pipeline\_config)
Initialize the config from a PipelineConfig.
## `ArchConfigWithAttentionKVCache` {#max.pipelines.lib.interfaces.ArchConfigWithAttentionKVCache}
> class max.pipelines.lib.interfaces.ArchConfigWithAttentionKVCache(dtype, devices=\<factory>, cache\_dtype=None, kv\_cache=\<factory>, data\_parallel\_degree=1, user\_provided\_max\_length=None, huggingface\_config=None, \_kv\_params=None)
Bases: [`ArchConfigWithKVCache`](#max.pipelines.lib.interfaces.ArchConfigWithKVCache), [`ABC`](https://docs.python.org/3/library/abc.html#abc.ABC)
Predefined configuration for architectures that use attention KV cache blocks.
Subclasses must define the following attributes:
* num\_key\_value\_heads: int
* head\_dim: int
* num\_layers: int
* model\_max\_seq\_len: int
### `cache_dtype` {#max.pipelines.lib.interfaces.ArchConfigWithAttentionKVCache.cache_dtype}
> cache\_dtype: [DType](../dtype.md#max.dtype.DType) | [None](https://docs.python.org/3/library/constants.html#None) = None
The data type to use for the KV cache.
### `data_parallel_degree` {#max.pipelines.lib.interfaces.ArchConfigWithAttentionKVCache.data_parallel_degree}
> data\_parallel\_degree: [int](https://docs.python.org/3/library/functions.html#int) = 1
The data parallel degree to use when running the model.
### `devices` {#max.pipelines.lib.interfaces.ArchConfigWithAttentionKVCache.devices}
> devices: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[DeviceRef](../graph/ops.md#max.graph.ops.DeviceRef)]
The physical devices to use when running the model.
### `dtype` {#max.pipelines.lib.interfaces.ArchConfigWithAttentionKVCache.dtype}
> dtype: [DType](../dtype.md#max.dtype.DType)
The data type to use for the model.
### `get_kv_params()` {#max.pipelines.lib.interfaces.ArchConfigWithAttentionKVCache.get_kv_params}
> get\_kv\_params()
Returns the KV cache parameters for this architecture.
### `get_max_seq_len()` {#max.pipelines.lib.interfaces.ArchConfigWithAttentionKVCache.get_max_seq_len}
> get\_max\_seq\_len()
Returns the maximum sequence length the model can process.
Returns `max_length` if set, otherwise `model_max_seq_len`.
Raises ValueError if `max_length` exceeds `model_max_seq_len`.
### `kv_cache` {#max.pipelines.lib.interfaces.ArchConfigWithAttentionKVCache.kv_cache}
> kv\_cache: KVCacheConfig
The KV cache configuration to use when running the model.
### `model_max_seq_len` {#max.pipelines.lib.interfaces.ArchConfigWithAttentionKVCache.model_max_seq_len}
> abstract property model\_max\_seq\_len: [int](https://docs.python.org/3/library/functions.html#int)
The maximum sequence length that can be processed by the model.
### `num_key_value_heads` {#max.pipelines.lib.interfaces.ArchConfigWithAttentionKVCache.num_key_value_heads}
> abstract property num\_key\_value\_heads: [int](https://docs.python.org/3/library/functions.html#int)
Number of key-value heads to use for the KV cache.
### `num_layers` {#max.pipelines.lib.interfaces.ArchConfigWithAttentionKVCache.num_layers}
> abstract property num\_layers: [int](https://docs.python.org/3/library/functions.html#int)
Number of hidden layers in the model.
### `user_provided_max_length` {#max.pipelines.lib.interfaces.ArchConfigWithAttentionKVCache.user_provided_max_length}
> user\_provided\_max\_length: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) = None
Override for the maximum sequence length.
## `ArchConfigWithKVCache` {#max.pipelines.lib.interfaces.ArchConfigWithKVCache}
> class max.pipelines.lib.interfaces.ArchConfigWithKVCache(\*args, \*\*kwargs)
Bases: [`ArchConfig`](#max.pipelines.lib.interfaces.ArchConfig), [`Protocol`](https://docs.python.org/3/library/typing.html#typing.Protocol)
Config for a model architecture that uses a KV cache.
### `get_kv_params()` {#max.pipelines.lib.interfaces.ArchConfigWithKVCache.get_kv_params}
> get\_kv\_params()
KV cache parameters to use when running the model.
## `ComponentModel` {#max.pipelines.lib.interfaces.ComponentModel}
> class max.pipelines.lib.interfaces.ComponentModel(config, encoding, devices, weights)
Bases: [`ABC`](https://docs.python.org/3/library/abc.html#abc.ABC)
Base interface for component models with weight-backed execution.
## `DiffusionPipeline` {#max.pipelines.lib.interfaces.DiffusionPipeline}
> class max.pipelines.lib.interfaces.DiffusionPipeline(pipeline\_config, session, devices, weight\_paths, \*\*kwargs)
Bases: [`ABC`](https://docs.python.org/3/library/abc.html#abc.ABC)
Base class for diffusion pipelines.
Subclasses must define components mapping component names to ComponentModel types.
## `GenerateMixin` {#max.pipelines.lib.interfaces.GenerateMixin}
> class max.pipelines.lib.interfaces.GenerateMixin(\*args, \*\*kwargs)
Bases: [`Protocol`](https://docs.python.org/3/library/typing.html#typing.Protocol)\[`TextGenerationContextType`, `RequestType`]
Protocol for pipelines that support text generation.
### `execute()` {#max.pipelines.lib.interfaces.GenerateMixin.execute}
> execute(inputs)
Executes the pipeline for the given inputs.
### `generate_async()` {#max.pipelines.lib.interfaces.GenerateMixin.generate_async}
> async generate\_async(prompts)
Generates outputs asynchronously for the given prompts.
## `KVCacheMixin` {#max.pipelines.lib.interfaces.KVCacheMixin}
### `load_kv_managers()` {#max.pipelines.lib.interfaces.KVCacheMixin.load_kv_managers}
> load\_kv\_managers(kv\_params, max\_batch\_size, max\_seq\_len, session, available\_cache\_memory)
Given KV cache parameters and an InferenceSession, loads the KV cache managers.
**Parameters:**
* kv\_params ([KVCacheParamInterface](../nn/legacy/kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheParamInterface)) – KV cache parameters.
* max\_batch\_size ([int](https://docs.python.org/3/library/functions.html#int)) – Maximum batch size of the model.
* max\_seq\_len ([int](https://docs.python.org/3/library/functions.html#int)) – Maximum sequence length of the model.
* session ([InferenceSession](../engine.md#max.engine.InferenceSession)) – Inference session to compile and init the KV cache.
* available\_cache\_memory ([int](https://docs.python.org/3/library/functions.html#int)) – Amount of memory available to the KV cache,
in bytes.
## `ModelInputs` {#max.pipelines.lib.interfaces.ModelInputs}
> class max.pipelines.lib.interfaces.ModelInputs
Bases: [`object`](https://docs.python.org/3/library/functions.html#object)
Base class for model inputs.
Use this class to encapsulate inputs for your model; you may store any
number of dataclass fields.
The following example demonstrates how to create a custom inputs class:
```python
from max.driver import Buffer
from max.dtype import DType
from max.pipelines.lib.interfaces import ModelInputs


class ReplitInputs(ModelInputs):
    tokens: Buffer
    input_row_offsets: Buffer

    def __init__(self, tokens: Buffer, input_row_offsets: Buffer):
        self.tokens = tokens
        self.input_row_offsets = input_row_offsets


tokens = Buffer.zeros((1, 2, 3), DType.int64)
input_row_offsets = Buffer.zeros((1, 1, 1), DType.int64)

# Initialize inputs
inputs = ReplitInputs(tokens=tokens, input_row_offsets=input_row_offsets)

# Access tensors
list(inputs) == [tokens, input_row_offsets]  # Output: True
```
### `hidden_states` {#max.pipelines.lib.interfaces.ModelInputs.hidden_states}
> hidden\_states: [Buffer](../driver.md#max.driver.Buffer) | [list](https://docs.python.org/3/library/stdtypes.html#list)\[[Buffer](../driver.md#max.driver.Buffer)] | [None](https://docs.python.org/3/library/constants.html#None) = None
Hidden states for a variable number of tokens per sequence.
For data parallel models, this can be a list of Buffers where each Buffer
contains hidden states for the sequences assigned to that device.
### `kv_cache_inputs` {#max.pipelines.lib.interfaces.ModelInputs.kv_cache_inputs}
> kv\_cache\_inputs: KVCacheInputs | [None](https://docs.python.org/3/library/constants.html#None) = None
### `lora_ids` {#max.pipelines.lib.interfaces.ModelInputs.lora_ids}
> lora\_ids: [Buffer](../driver.md#max.driver.Buffer) | [None](https://docs.python.org/3/library/constants.html#None) = None
Buffer containing the LoRA ids.
### `lora_ranks` {#max.pipelines.lib.interfaces.ModelInputs.lora_ranks}
> lora\_ranks: [Buffer](../driver.md#max.driver.Buffer) | [None](https://docs.python.org/3/library/constants.html#None) = None
Buffer containing the LoRA ranks.
### `update()` {#max.pipelines.lib.interfaces.ModelInputs.update}
> update(\*\*kwargs)
Updates attributes from keyword arguments (only existing, non-None).
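The update semantics ("only existing, non-None") can be sketched with a simplified stand-in class, not the actual implementation:

```python
class SketchInputs:
    """Simplified stand-in for ModelInputs' update() semantics."""

    def __init__(self) -> None:
        self.tokens = "t0"
        self.kv_cache_inputs = None

    def update(self, **kwargs) -> None:
        for key, value in kwargs.items():
            # Skip None values and attributes that don't already exist.
            if value is not None and hasattr(self, key):
                setattr(self, key, value)


inputs = SketchInputs()
inputs.update(tokens="t1", kv_cache_inputs=None, bogus=3)
```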
## `ModelOutputs` {#max.pipelines.lib.interfaces.ModelOutputs}
Outputs returned by a model's execute() step.
### `hidden_states` {#max.pipelines.lib.interfaces.ModelOutputs.hidden_states}
> hidden\_states: [Buffer](../driver.md#max.driver.Buffer) | [list](https://docs.python.org/3/library/stdtypes.html#list)\[[Buffer](../driver.md#max.driver.Buffer)] | [None](https://docs.python.org/3/library/constants.html#None) = None
Hidden states for a variable number of tokens per sequence.
For data parallel models, this can be a list of Buffers where each Buffer
contains hidden states for the sequences assigned to that device.
### `logit_offsets` {#max.pipelines.lib.interfaces.ModelOutputs.logit_offsets}
> logit\_offsets: [Buffer](../driver.md#max.driver.Buffer) | [None](https://docs.python.org/3/library/constants.html#None) = None
Offsets to access variable length logits for each sequence.
### `logits` {#max.pipelines.lib.interfaces.ModelOutputs.logits}
> logits: [Buffer](../driver.md#max.driver.Buffer)
Logits for a variable number of tokens per sequence.
### `next_token_logits` {#max.pipelines.lib.interfaces.ModelOutputs.next_token_logits}
> next\_token\_logits: [Buffer](../driver.md#max.driver.Buffer) | [None](https://docs.python.org/3/library/constants.html#None) = None
Logits for just the next token.
## `PipelineModel` {#max.pipelines.lib.interfaces.PipelineModel}
> class max.pipelines.lib.interfaces.PipelineModel(pipeline\_config, session, huggingface\_config, encoding, devices, kv\_cache\_config, weights, adapter, return\_logits, return\_hidden\_states=ReturnHiddenStates.NONE)
Bases: [`ABC`](https://docs.python.org/3/library/abc.html#abc.ABC), [`Generic`](https://docs.python.org/3/library/typing.html#typing.Generic)\[`BaseContextType`]
A pipeline model with setup, input preparation and execution methods.
### `calculate_max_seq_len()` {#max.pipelines.lib.interfaces.PipelineModel.calculate_max_seq_len}
> abstract classmethod calculate\_max\_seq\_len(pipeline\_config, huggingface\_config)
Calculates the optimal max sequence length for the model.
Models are expected to implement this method. The following example
shows how to implement it for a Mistral model:
```python
class MistralModel(PipelineModel):
    @classmethod
    def calculate_max_seq_len(cls, pipeline_config, huggingface_config) -> int:
        try:
            return upper_bounded_default(
                upper_bound=huggingface_config.max_seq_len,
                default=pipeline_config.max_length,
            )
        except ValueError as e:
            raise ValueError(
                "Unable to infer max_length for Mistral, the provided "
                f"max_length ({pipeline_config.max_length}) exceeds the "
                f"model's max_seq_len ({huggingface_config.max_seq_len})."
            ) from e
```
**Parameters:**
* pipeline\_config ([PipelineConfig](config.md#max.pipelines.lib.config.PipelineConfig)) – Configuration for the pipeline.
* huggingface\_config (AutoConfig) – Hugging Face model configuration.
### `compute_log_probabilities()` {#max.pipelines.lib.interfaces.PipelineModel.compute_log_probabilities}
> compute\_log\_probabilities(session, model\_inputs, model\_outputs, next\_tokens, batch\_top\_n, batch\_echo)
Optional method that can be overridden to compute log probabilities.
**Parameters:**
* session ([InferenceSession](../engine.md#max.engine.InferenceSession)) – Inference session to compute log probabilities within.
* model\_inputs ([ModelInputs](#max.pipelines.lib.interfaces.ModelInputs)) – Inputs to the model returned by
prepare\_\*\_token\_inputs().
* model\_outputs ([ModelOutputs](#max.pipelines.lib.interfaces.ModelOutputs)) – Outputs returned by execute().
* next\_tokens ([Buffer](../driver.md#max.driver.Buffer)) – Sampled tokens. Should have shape=\[batch size]
* batch\_top\_n ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]) – Number of top log probabilities to return per input in
the batch. For any element where top\_n == 0, the
LogProbabilities is skipped.
* batch\_echo ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[bool](https://docs.python.org/3/library/functions.html#bool)]) – Whether to include input tokens in the returned log
probabilities.
### `dtype` {#max.pipelines.lib.interfaces.PipelineModel.dtype}
> property dtype: [DType](../dtype.md#max.dtype.DType)
Returns the model data type (from encoding or pipeline config).
### `estimate_activation_memory()` {#max.pipelines.lib.interfaces.PipelineModel.estimate_activation_memory}
> classmethod estimate\_activation\_memory(pipeline\_config, huggingface\_config)
Estimates the activation memory required for model execution.
This accounts for temporary memory buffers used during model execution,
such as intermediate activations and working buffers.
The default implementation returns 0 for backward compatibility.
Models with significant activation memory requirements should override
this method to provide accurate estimates.
**Parameters:**
* pipeline\_config ([PipelineConfig](config.md#max.pipelines.lib.config.PipelineConfig)) – Pipeline configuration
* huggingface\_config (AutoConfig) – Hugging Face model configuration
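An override typically multiplies the largest working-buffer shape by the element size. A minimal, hypothetical helper (the two-buffer heuristic and all names here are assumptions for illustration, not the MAX implementation):

```python
def estimate_activation_bytes(
    batch_size: int,
    seq_len: int,
    hidden_dim: int,
    dtype_bytes: int,
    n_buffers: int = 2,
) -> int:
    """Crude upper bound: n working buffers of shape [batch, seq, hidden]."""
    return n_buffers * batch_size * seq_len * hidden_dim * dtype_bytes
```

A model's `estimate_activation_memory` override would return a value like this instead of the default 0.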
### `execute()` {#max.pipelines.lib.interfaces.PipelineModel.execute}
> abstract execute(model\_inputs)
Executes the graph with the given inputs.
**Parameters:**
model\_inputs ([ModelInputs](#max.pipelines.lib.interfaces.ModelInputs)) – The model inputs to execute, containing tensors and any other
required data for model execution.
**Returns:**
ModelOutputs containing the pipeline’s output tensors.
This is an abstract method that must be implemented by concrete PipelineModels
to define their specific execution logic.
### `execute_with_capture()` {#max.pipelines.lib.interfaces.PipelineModel.execute_with_capture}
> execute\_with\_capture(model\_inputs, batch\_size)
Executes the model with optional capture handling.
Subclasses can override this to integrate device graph capture/replay.
### `finalize_pipeline_config()` {#max.pipelines.lib.interfaces.PipelineModel.finalize_pipeline_config}
> classmethod finalize\_pipeline\_config(pipeline\_config)
Finalizes the pipeline configuration.
This method is called after the pipeline configuration is resolved.
It can be overridden to perform any finalization steps that are needed.
### `prepare_initial_token_inputs()` {#max.pipelines.lib.interfaces.PipelineModel.prepare_initial_token_inputs}
> abstract prepare\_initial\_token\_inputs(replica\_batches, kv\_cache\_inputs=None, return\_n\_logits=1)
Prepares the initial inputs to be passed to `.execute()`.
The inputs and functionality can vary per model. For example, model
inputs could include encoded tensors, unique IDs per tensor when using
a KV cache manager, and `kv_cache_inputs` (or None if the model does
not use KV cache). This method typically batches encoded tensors,
claims a KV cache slot if needed, and returns the inputs and caches.
### `prepare_next_token_inputs()` {#max.pipelines.lib.interfaces.PipelineModel.prepare_next_token_inputs}
> abstract prepare\_next\_token\_inputs(next\_tokens, prev\_model\_inputs)
Prepares the secondary inputs to be passed to `.execute()`.
While `prepare_initial_token_inputs` manages the initial inputs, this method updates the inputs for each step in a multi-step execution pattern.
### `signal_buffers` {#max.pipelines.lib.interfaces.PipelineModel.signal_buffers}
> property signal\_buffers: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[Buffer](../driver.md#max.driver.Buffer)]
Lazily initialize signal buffers for multi-GPU communication collectives.
Signal buffers are only needed during model execution, not during compilation.
By deferring their allocation, we avoid memory allocation in compile-only mode.
**Returns:**
List of signal buffer tensors, one per device for multi-device setups,
or an empty list for single-device setups or compile-only mode.
## `PixelModelInputs` {#max.pipelines.lib.interfaces.PixelModelInputs}
> class max.pipelines.lib.interfaces.PixelModelInputs(\*, tokens, tokens\_2=None, negative\_tokens=None, negative\_tokens\_2=None, extra\_params=\<factory>, timesteps=\<factory>, sigmas=\<factory>, latents=\<factory>, latent\_image\_ids=\<factory>, height=1024, width=1024, num\_inference\_steps=50, guidance\_scale=3.5, guidance=None, true\_cfg\_scale=1.0, num\_warmup\_steps=0, num\_images\_per\_prompt=1)
Bases: [`object`](https://docs.python.org/3/library/functions.html#object)
Common input container for pixel-generation models.
Provides a consistent set of fields used across multiple pixel
pipelines and models.
### `extra_params` {#max.pipelines.lib.interfaces.PixelModelInputs.extra_params}
> extra\_params: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]]
A bag of model-specific numeric parameters not represented as explicit fields.
Typical uses:
* Architecture-specific knobs (e.g., cfg\_normalization arrays, scaling vectors)
* Precomputed per-step values not worth standardizing across all models
* Small numeric tensors that are easier to carry as named extras than formal fields
Values are expected to be numpy arrays (ndarray) to keep the contract consistent,
but you can relax this if your codebase needs non-array values.
### `from_context()` {#max.pipelines.lib.interfaces.PixelModelInputs.from_context}
> classmethod from\_context(context)
Build an instance from a context-like dict.
Policy:
* If a key is missing: the dataclass default applies automatically.
* If a key is present with value None: treat as missing and substitute the class default
(including subclass overrides).
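The None-as-missing policy can be sketched with a plain dataclass (a simplified analog for illustration, not the actual implementation):

```python
from dataclasses import dataclass, fields


@dataclass
class Inputs:
    height: int = 1024
    width: int = 1024

    @classmethod
    def from_context(cls, context: dict) -> "Inputs":
        # Keys that are missing OR explicitly None fall back to the class
        # default (including any subclass override of that default).
        kwargs = {
            f.name: context[f.name]
            for f in fields(cls)
            if context.get(f.name) is not None
        }
        return cls(**kwargs)
```

So a context of `{"height": None, "width": 512}` yields `height=1024, width=512`.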
### `guidance` {#max.pipelines.lib.interfaces.PixelModelInputs.guidance}
> guidance: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]] | [None](https://docs.python.org/3/library/constants.html#None) = None
Optional guidance tensor.
* Some pipelines precompute guidance weights/tensors (e.g., per-token weights, per-step weights).
* None is meaningful here: it means “no explicit guidance tensor supplied”.
* Unlike scalar fields, None is preserved (not replaced).
### `guidance_scale` {#max.pipelines.lib.interfaces.PixelModelInputs.guidance_scale}
> guidance\_scale: [float](https://docs.python.org/3/library/functions.html#float) = 3.5
Guidance scale for classifier-free guidance (CFG).
* A higher value typically increases adherence to the prompt but can reduce diversity.
* This is expected to be a real float (not None).
* If a context provides guidance\_scale=None, from\_context() substitutes the default.
### `height` {#max.pipelines.lib.interfaces.PixelModelInputs.height}
> height: [int](https://docs.python.org/3/library/functions.html#int) = 1024
Output height in pixels.
* This is a required scalar (not None).
* If a context provides height=None, from\_context() treats that as “not provided”
and substitutes this default value (or a subclass override).
### `latent_image_ids` {#max.pipelines.lib.interfaces.PixelModelInputs.latent_image_ids}
> latent\_image\_ids: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]]
Optional latent image IDs / positional identifiers for latents.
* Some pipelines attach per-latent identifiers for caching, routing, or conditioning.
* Often used to avoid recomputation of image-id embeddings across steps.
* If unused, it may remain empty.
### `latents` {#max.pipelines.lib.interfaces.PixelModelInputs.latents}
> latents: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]]
Initial latent noise tensor (or initial latent state).
* For diffusion/flow models, this is typically random noise seeded per request.
* Shape depends on model: commonly \[B, C, H/8, W/8] for image latents,
or \[B, T, C, H/8, W/8] for video latents.
* If your pipeline generates latents internally, you may leave it empty.
(Model-specific subclasses can enforce non-empty via \_\_post\_init\_\_.)
### `negative_tokens` {#max.pipelines.lib.interfaces.PixelModelInputs.negative_tokens}
> negative\_tokens: [TokenBuffer](../interfaces.md#max.interfaces.TokenBuffer) | [None](https://docs.python.org/3/library/constants.html#None) = None
Negative prompt tokens for the primary encoder.
Used for classifier-free guidance (CFG) or similar conditioning schemes.
If your pipeline does not use negative prompts, leave as None.
### `negative_tokens_2` {#max.pipelines.lib.interfaces.PixelModelInputs.negative_tokens_2}
> negative\_tokens\_2: [TokenBuffer](../interfaces.md#max.interfaces.TokenBuffer) | [None](https://docs.python.org/3/library/constants.html#None) = None
Negative prompt tokens for the secondary encoder (for dual-encoder models).
If the model is single-encoder or you do not use negative prompts, leave as None.
### `num_images_per_prompt` {#max.pipelines.lib.interfaces.PixelModelInputs.num_images_per_prompt}
> num\_images\_per\_prompt: [int](https://docs.python.org/3/library/functions.html#int) = 1
Number of images/videos to generate per prompt.
* Commonly used for “same prompt, multiple samples” behavior.
* Must be > 0.
* For video generation, the name is kept for historical compatibility.
### `num_inference_steps` {#max.pipelines.lib.interfaces.PixelModelInputs.num_inference_steps}
> num\_inference\_steps: [int](https://docs.python.org/3/library/functions.html#int) = 50
Number of denoising/inference steps.
* This is a required scalar (not None).
* If a context provides num\_inference\_steps=None, from\_context() treats that as
“not provided” and substitutes this default value (or a subclass override).
### `num_warmup_steps` {#max.pipelines.lib.interfaces.PixelModelInputs.num_warmup_steps}
> num\_warmup\_steps: [int](https://docs.python.org/3/library/functions.html#int) = 0
Number of warmup steps.
* Used in some schedulers/pipelines to handle initial steps differently
(e.g., scheduler stabilization, cache warmup, etc.).
* Must be >= 0.
### `sigmas` {#max.pipelines.lib.interfaces.PixelModelInputs.sigmas}
> sigmas: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]]
Precomputed sigma schedule for denoising.
* Usually a 1D float32 numpy array of length num\_inference\_steps
corresponding to the noise level per step.
* Some schedulers are sigma-based; others are timestep-based; some use both.
* If unused, it may remain empty unless your model subclass requires it.
### `timesteps` {#max.pipelines.lib.interfaces.PixelModelInputs.timesteps}
> timesteps: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[float32]]
Precomputed denoising timestep schedule.
* Usually a 1D float32 numpy array of length num\_inference\_steps
(exact semantics depend on your scheduler).
* If your pipeline precomputes the scheduler trajectory, you pass it here.
* Some models may not require explicit timesteps; in that case it may remain empty.
(Model-specific subclasses can enforce non-empty via \_\_post\_init\_\_.)
### `tokens` {#max.pipelines.lib.interfaces.PixelModelInputs.tokens}
> tokens: [TokenBuffer](../interfaces.md#max.interfaces.TokenBuffer)
Primary encoder token buffer.
This is the main prompt representation consumed by the model’s text encoder.
Required for all models.
### `tokens_2` {#max.pipelines.lib.interfaces.PixelModelInputs.tokens_2}
> tokens\_2: [TokenBuffer](../interfaces.md#max.interfaces.TokenBuffer) | [None](https://docs.python.org/3/library/constants.html#None) = None
Secondary encoder token buffer (for dual-encoder models).
Examples: architectures that have a second text encoder stream or pooled embeddings.
If the model is single-encoder, leave as None.
### `true_cfg_scale` {#max.pipelines.lib.interfaces.PixelModelInputs.true_cfg_scale}
> true\_cfg\_scale: [float](https://docs.python.org/3/library/functions.html#float) = 1.0
“True CFG” scale used by certain pipelines/models.
* Some architectures distinguish between the user-facing guidance\_scale and an internal
scale applied to a different normalization or conditioning pathway.
* Defaults to 1.0 for pipelines that do not use this feature.
### `width` {#max.pipelines.lib.interfaces.PixelModelInputs.width}
> width: [int](https://docs.python.org/3/library/functions.html#int) = 1024
Output width in pixels.
* This is a required scalar (not None).
* If a context provides width=None, from\_context() treats that as “not provided”
and substitutes this default value (or a subclass override).
---
## log_probabilities
## `compute_log_probabilities_ragged()` {#max.pipelines.lib.log_probabilities.compute_log_probabilities_ragged}
> max.pipelines.lib.log\_probabilities.compute\_log\_probabilities\_ragged(device, model, \*, input\_row\_offsets, logits, next\_token\_logits, tokens, sampled\_tokens, batch\_top\_n, batch\_echo)
Computes the log probabilities for ragged model outputs.
**Parameters:**
* device ([Device](../driver.md#max.driver.Device)) – Device on which to do the bulk of the log probabilities
computation. A small amount of computation still occurs on the
host regardless of this setting.
* model ([Model](../engine.md#max.engine.Model)) – A compiled version of a graph from the
‘log\_probabilities\_ragged\_graph’ function.
* input\_row\_offsets ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]]) – Token offsets into token-indexed buffers, by batch
index. Should have 1 more element than there are batches (batch n
is token indices \[input\_row\_offsets\[n], input\_row\_offsets\[n+1])).
* logits ([Buffer](../driver.md#max.driver.Buffer) | None) – (tokens, vocab\_dim) tensor full of tensor logits. Token
dimension mapped to batches using input\_row\_offsets. May be
omitted only if all ‘batch\_echo’ values are False.
* next\_token\_logits ([Buffer](../driver.md#max.driver.Buffer)) – (batch\_dim, vocab\_dim) tensor full of tensor logits
for the next token in each batch item.
* tokens ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]]) – (total\_tokens,) flat token array for the batch; indices
per batch given by input\_row\_offsets.
* sampled\_tokens ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]]) – (batch\_dim,) tensor of sampled token per batch
* batch\_top\_n ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[int](https://docs.python.org/3/library/functions.html#int)]) – Number of top log probabilities to return per input in
the batch. For any element where top\_n == 0, the
LogProbabilities is skipped.
* batch\_echo ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[bool](https://docs.python.org/3/library/functions.html#bool)]) – Whether to include input tokens in the returned log
probabilities.
**Returns:**
Computed log probabilities for each item in the batch.
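The ragged layout described for `input_row_offsets` can be illustrated with plain NumPy (illustrative values only):

```python
import numpy as np

# Two batch items: batch 0 owns tokens[0:3], batch 1 owns tokens[3:5].
# Note the offsets array has one more element than there are batches.
input_row_offsets = np.array([0, 3, 5])
tokens = np.array([11, 12, 13, 21, 22])

batches = [
    tokens[input_row_offsets[n] : input_row_offsets[n + 1]]
    for n in range(len(input_row_offsets) - 1)
]
```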
## `log_probabilities_ragged_graph()` {#max.pipelines.lib.log_probabilities.log_probabilities_ragged_graph}
> max.pipelines.lib.log\_probabilities.log\_probabilities\_ragged\_graph(device, \*, levels)
Create a graph to compute log probabilities over ragged inputs.
A model obtained by this graph is a required input to
‘compute\_log\_probabilities\_ragged’.
**Parameters:**
* device ([DeviceRef](../graph/type.md#max.graph.type.DeviceRef)) – The type of device this graph will need to run on.
* levels ([int](https://docs.python.org/3/library/functions.html#int)) – log2(max\_k+1) for the desired maximum top-k you’d like to
support. To support the OpenAI API maximum of 5 logprobs, use
levels=3. Higher levels can be used to support higher k.
**Return type:**
[Graph](../graph/Graph.md#max.graph.Graph)
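Picking `levels` for a desired maximum top-k follows directly from the formula above (a small helper, not part of the API):

```python
import math


def levels_for_top_k(max_k: int) -> int:
    # levels = log2(max_k + 1), rounded up to a whole number of levels.
    return math.ceil(math.log2(max_k + 1))
```

For example, the OpenAI API maximum of 5 logprobs needs `levels=3`, which also covers any k up to 7.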
---
## lora_config
MAX LoRA configuration.
## `LoRAConfig` {#max.pipelines.lib.lora_config.LoRAConfig}
> class max.pipelines.lib.lora\_config.LoRAConfig(\*, config\_file=None, section\_name=None, enable\_lora=False, lora\_paths=\<factory>, max\_lora\_rank=16, max\_num\_loras=1)
### `enable_lora` {#max.pipelines.lib.lora_config.LoRAConfig.enable_lora}
> enable\_lora: [bool](https://docs.python.org/3/library/functions.html#bool)
### `lora_paths` {#max.pipelines.lib.lora_config.LoRAConfig.lora_paths}
> lora\_paths: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[str](https://docs.python.org/3/library/stdtypes.html#str)]
### `max_lora_rank` {#max.pipelines.lib.lora_config.LoRAConfig.max_lora_rank}
> max\_lora\_rank: [int](https://docs.python.org/3/library/functions.html#int)
### `max_num_loras` {#max.pipelines.lib.lora_config.LoRAConfig.max_num_loras}
> max\_num\_loras: [int](https://docs.python.org/3/library/functions.html#int)
### `model_config` {#max.pipelines.lib.lora_config.LoRAConfig.model_config}
> model\_config: ClassVar\[ConfigDict] = {}
Configuration for the model, should be a dictionary conforming to pydantic's `ConfigDict`.
### `model_post_init()` {#max.pipelines.lib.lora_config.LoRAConfig.model_post_init}
> model\_post\_init(context, /)
This function is meant to behave like a BaseModel method to initialise private attributes.
It takes context as an argument since that’s what pydantic-core passes when calling it.
**Parameters:**
* self (BaseModel) – The BaseModel instance.
* context (Any) – The context.
---
## model_config
## `MAXModelConfig` {#max.pipelines.lib.model_config.MAXModelConfig}
### `allow_safetensors_weights_fp32_bf6_bidirectional_cast` {#max.pipelines.lib.model_config.MAXModelConfig.allow_safetensors_weights_fp32_bf6_bidirectional_cast}
> allow\_safetensors\_weights\_fp32\_bf6\_bidirectional\_cast: [bool](https://docs.python.org/3/library/functions.html#bool)
### `create_kv_cache_config()` {#max.pipelines.lib.model_config.MAXModelConfig.create_kv_cache_config}
> create\_kv\_cache\_config(\*\*kv\_cache\_kwargs)
Create and set the KV cache configuration with the given parameters.
This method creates a new KVCacheConfig from the provided keyword arguments
and automatically sets the cache\_dtype based on the model’s quantization
encoding (or any explicit override in kv\_cache\_kwargs).
**Parameters:**
\*\*kv\_cache\_kwargs – Keyword arguments to pass to KVCacheConfig constructor.
Common options include:
* cache\_strategy: The KV cache strategy (continuous, paged, etc.)
* kv\_cache\_page\_size: Number of tokens per page for paged cache
* enable\_prefix\_caching: Whether to enable prefix caching
* device\_memory\_utilization: Fraction of device memory to use
* cache\_dtype: Override for the cache data type
**Return type:**
None
### `data_parallel_degree` {#max.pipelines.lib.model_config.MAXModelConfig.data_parallel_degree}
> data\_parallel\_degree: [int](https://docs.python.org/3/library/functions.html#int)
### `default_device_spec` {#max.pipelines.lib.model_config.MAXModelConfig.default_device_spec}
> property default\_device\_spec: [DeviceSpec](../driver.md#max.driver.DeviceSpec)
Returns the default device spec for the model.
This is the first device spec in the list, used for device spec checks
throughout config validation.
**Returns:**
The default device spec for the model.
### `device_specs` {#max.pipelines.lib.model_config.MAXModelConfig.device_specs}
> device\_specs: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[DeviceSpec](../driver.md#max.driver.DeviceSpec)]
### `diffusers_config` {#max.pipelines.lib.model_config.MAXModelConfig.diffusers_config}
> property diffusers\_config: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [None](https://docs.python.org/3/library/constants.html#None)
Retrieve the diffusers config for diffusion pipelines.
Note: For multiprocessing, \_\_getstate\_\_ clears \_diffusers\_config
before pickling. Each worker process will reload the config fresh.
**Returns:**
The diffusers config dict if this is a diffusion pipeline, None otherwise.
The dict will have a structure with “\_class\_name” and “components” keys,
where each component includes “class\_name” and “config\_dict” fields.
### `force_download` {#max.pipelines.lib.model_config.MAXModelConfig.force_download}
> force\_download: [bool](https://docs.python.org/3/library/functions.html#bool)
### `generation_config` {#max.pipelines.lib.model_config.MAXModelConfig.generation_config}
> property generation\_config: GenerationConfig
Retrieve the Hugging Face GenerationConfig for this model.
This property lazily loads the GenerationConfig from the model repository
and caches it to avoid repeated remote fetches.
**Returns:**
The GenerationConfig for the model, containing generation parameters
like max\_length, temperature, top\_p, etc. If loading fails, returns
a default GenerationConfig.
### `graph_quantization_encoding` {#max.pipelines.lib.model_config.MAXModelConfig.graph_quantization_encoding}
> property graph\_quantization\_encoding: [QuantizationEncoding](../graph/quantization.md#max.graph.quantization.QuantizationEncoding) | [None](https://docs.python.org/3/library/constants.html#None)
Converts the CLI encoding to a MAX Graph quantization encoding.
**Returns:**
The graph quantization encoding corresponding to the CLI encoding.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If no CLI encoding was specified.
### `huggingface_config` {#max.pipelines.lib.model_config.MAXModelConfig.huggingface_config}
> property huggingface\_config: AutoConfig | [None](https://docs.python.org/3/library/constants.html#None)
Returns the Hugging Face model config (loaded on first access).
### `huggingface_model_repo` {#max.pipelines.lib.model_config.MAXModelConfig.huggingface_model_repo}
> property huggingface\_model\_repo: [HuggingFaceRepo](hf_utils.md#max.pipelines.lib.hf_utils.HuggingFaceRepo)
Returns the Hugging Face repo handle for the model.
### `huggingface_model_revision` {#max.pipelines.lib.model_config.MAXModelConfig.huggingface_model_revision}
> huggingface\_model\_revision: [str](https://docs.python.org/3/library/stdtypes.html#str)
### `huggingface_weight_repo` {#max.pipelines.lib.model_config.MAXModelConfig.huggingface_weight_repo}
> property huggingface\_weight\_repo: [HuggingFaceRepo](hf_utils.md#max.pipelines.lib.hf_utils.HuggingFaceRepo)
Returns the Hugging Face repo handle for weight files.
### `huggingface_weight_repo_id` {#max.pipelines.lib.model_config.MAXModelConfig.huggingface_weight_repo_id}
> property huggingface\_weight\_repo\_id: [str](https://docs.python.org/3/library/stdtypes.html#str)
Returns the Hugging Face repo ID used for weight files.
### `huggingface_weight_revision` {#max.pipelines.lib.model_config.MAXModelConfig.huggingface_weight_revision}
> huggingface\_weight\_revision: [str](https://docs.python.org/3/library/stdtypes.html#str)
### `kv_cache` {#max.pipelines.lib.model_config.MAXModelConfig.kv_cache}
> kv\_cache: KVCacheConfig
### `model_config` {#max.pipelines.lib.model_config.MAXModelConfig.model_config}
> model\_config: ClassVar\[ConfigDict] = {'arbitrary\_types\_allowed': True}
Configuration for the model, should be a dictionary conforming to pydantic's `ConfigDict`.
### `model_name` {#max.pipelines.lib.model_config.MAXModelConfig.model_name}
> property model\_name: [str](https://docs.python.org/3/library/stdtypes.html#str)
Returns the served model name or model path.
### `model_path` {#max.pipelines.lib.model_config.MAXModelConfig.model_path}
> model\_path: [str](https://docs.python.org/3/library/stdtypes.html#str)
### `model_post_init()` {#max.pipelines.lib.model_config.MAXModelConfig.model_post_init}
> model\_post\_init(context, /)
This function is meant to behave like a BaseModel method to initialise private attributes.
It takes context as an argument since that’s what pydantic-core passes when calling it.
**Parameters:**
* self (BaseModel) – The BaseModel instance.
* context (Any) – The context.
**Return type:**
None
### `quantization_encoding` {#max.pipelines.lib.model_config.MAXModelConfig.quantization_encoding}
> quantization\_encoding: SupportedEncoding | [None](https://docs.python.org/3/library/constants.html#None)
### `resolve()` {#max.pipelines.lib.model_config.MAXModelConfig.resolve}
> resolve()
Validates and resolves the config.
This method is called after the model config is initialized, to ensure that all
config fields have been initialized to a valid state. It will also set
and update other fields which may not be determined / initialized in the
default factory.
In order:
1. Validate that the device\_specs provided are available
2. Parse the weight path(s) and initialize the \_weights\_repo\_id
**Return type:**
None
### `rope_type` {#max.pipelines.lib.model_config.MAXModelConfig.rope_type}
> rope\_type: RopeType | [None](https://docs.python.org/3/library/constants.html#None)
### `sampling_params_defaults` {#max.pipelines.lib.model_config.MAXModelConfig.sampling_params_defaults}
> property sampling\_params\_defaults: [SamplingParamsGenerationConfigDefaults](../interfaces.md#max.interfaces.SamplingParamsGenerationConfigDefaults)
Returns sampling defaults derived from the generation config.
### `served_model_name` {#max.pipelines.lib.model_config.MAXModelConfig.served_model_name}
> served\_model\_name: [str](https://docs.python.org/3/library/stdtypes.html#str) | [None](https://docs.python.org/3/library/constants.html#None)
### `set_cache_dtype_given_quantization_encoding()` {#max.pipelines.lib.model_config.MAXModelConfig.set_cache_dtype_given_quantization_encoding}
> set\_cache\_dtype\_given\_quantization\_encoding()
Determine the KV cache dtype based on quantization encoding configuration.
The dtype is determined in the following priority order:
1. Explicit override from kv\_cache.kv\_cache\_format (if set)
2. Derived from the model’s quantization\_encoding
3. Falls back to float32 if no encoding is specified
**Returns:**
The DType to use for the KV cache. Typical values are:
* DType.float32 for float32, q4\_k, q4\_0, q6\_k encodings
* DType.bfloat16 for bfloat16, float8\_e4m3fn, float4\_e2m1fnx2, gptq encodings
**Return type:**
[DType](../dtype.md#max.dtype.DType)
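The priority order can be sketched as a small resolver (string-keyed for illustration; the real method works with `DType` values and config objects):

```python
def resolve_cache_dtype(explicit_override, quantization_encoding):
    # 1. An explicit kv_cache_format override wins.
    if explicit_override is not None:
        return explicit_override
    # 2. Otherwise derive the dtype from the quantization encoding.
    derived = {
        "float32": "float32", "q4_k": "float32",
        "q4_0": "float32", "q6_k": "float32",
        "bfloat16": "bfloat16", "float8_e4m3fn": "bfloat16",
        "float4_e2m1fnx2": "bfloat16", "gptq": "bfloat16",
    }
    if quantization_encoding in derived:
        return derived[quantization_encoding]
    # 3. Fall back to float32 when no encoding is specified.
    return "float32"
```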
### `trust_remote_code` {#max.pipelines.lib.model_config.MAXModelConfig.trust_remote_code}
> trust\_remote\_code: [bool](https://docs.python.org/3/library/functions.html#bool)
### `use_subgraphs` {#max.pipelines.lib.model_config.MAXModelConfig.use_subgraphs}
> use\_subgraphs: [bool](https://docs.python.org/3/library/functions.html#bool)
### `validate_and_resolve_quantization_encoding_weight_path()` {#max.pipelines.lib.model_config.MAXModelConfig.validate_and_resolve_quantization_encoding_weight_path}
> validate\_and\_resolve\_quantization\_encoding\_weight\_path(default\_encoding)
Verifies that the quantization encoding and weight path are consistent.
**Parameters:**
default\_encoding (SupportedEncoding) – The default encoding to use if no encoding is provided.
**Return type:**
None
### `validate_and_resolve_rope_type()` {#max.pipelines.lib.model_config.MAXModelConfig.validate_and_resolve_rope_type}
> validate\_and\_resolve\_rope\_type(arch\_rope\_type)
Resolves rope\_type from architecture default if not set.
**Parameters:**
arch\_rope\_type (RopeType)
**Return type:**
None
### `validate_and_resolve_with_resolved_quantization_encoding()` {#max.pipelines.lib.model_config.MAXModelConfig.validate_and_resolve_with_resolved_quantization_encoding}
> validate\_and\_resolve\_with\_resolved\_quantization\_encoding(supported\_encodings, default\_weights\_format)
Validates model path and weight path against resolved quantization encoding.
Also resolves the KV cache strategy and finalizes the encoding config.
**Parameters:**
* supported\_encodings ([dict](https://docs.python.org/3/library/stdtypes.html#dict)\[SupportedEncoding, [list](https://docs.python.org/3/library/stdtypes.html#list)\[[KVCacheStrategy](../nn/legacy/kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheStrategy)]]) – A dictionary of supported encodings and their corresponding KV cache strategies.
* default\_weights\_format ([WeightsFormat](../graph/weights.md#max.graph.weights.WeightsFormat)) – The default weights format to use if no weights format is provided.
**Return type:**
None
### `validate_lora_compatibility()` {#max.pipelines.lib.model_config.MAXModelConfig.validate_lora_compatibility}
> validate\_lora\_compatibility()
Validates that LoRA configuration is compatible with model settings.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If LoRA is enabled but incompatible with current model configuration.
**Return type:**
None
### `validate_multi_gpu_supported()` {#max.pipelines.lib.model_config.MAXModelConfig.validate_multi_gpu_supported}
> validate\_multi\_gpu\_supported(multi\_gpu\_supported)
Validates that the model architecture supports multi-GPU inference.
**Parameters:**
multi\_gpu\_supported ([bool](https://docs.python.org/3/library/functions.html#bool)) – Whether the model architecture supports multi-GPU inference.
**Return type:**
None
### `vision_config_overrides` {#max.pipelines.lib.model_config.MAXModelConfig.vision_config_overrides}
> vision\_config\_overrides: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), Any]
### `weight_path` {#max.pipelines.lib.model_config.MAXModelConfig.weight_path}
> weight\_path: [list](https://docs.python.org/3/library/stdtypes.html#list)\[Path]
### `weights_size()` {#max.pipelines.lib.model_config.MAXModelConfig.weights_size}
> weights\_size()
Calculates the total size in bytes of all weight files in `weight_path`.
Attempts to find the weights locally first to avoid network
calls, checking in the following order:
1. If repo\_type is `RepoType.local`, it checks if the path
in weight\_path exists directly as a local file path.
2. Otherwise, if repo\_type is `RepoType.online`, it first checks the local
Hugging Face cache using `huggingface_hub.try_to_load_from_cache()`.
If not found in the cache, it falls back to querying the Hugging Face
Hub API via `HuggingFaceRepo.size_of()`.
**Returns:**
The total size of all weight files in bytes.
**Raises:**
* [FileNotFoundError](https://docs.python.org/3/library/exceptions.html#FileNotFoundError) – If repo\_type is `RepoType.local` and a file
specified in weight\_path is not found within the local repo
directory.
* [ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If `HuggingFaceRepo.size_of()` fails to retrieve the
file size from the Hugging Face Hub API (e.g., file metadata
not available or API error).
* [RuntimeError](https://docs.python.org/3/library/exceptions.html#RuntimeError) – If the determined repo\_type is unexpected.
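The local-first lookup order used by `weights_size()` can be sketched as follows. This is an illustrative simplification; `check_hf_cache` and `query_hub_api` are hypothetical stand-ins for `huggingface_hub.try_to_load_from_cache()` and `HuggingFaceRepo.size_of()`:

```python
# Illustrative sketch of the local-first weight size lookup described above;
# check_hf_cache and query_hub_api are hypothetical stand-ins.
import os

def total_weights_size(weight_paths, repo_type, local_root=".",
                       check_hf_cache=lambda p: None,
                       query_hub_api=lambda p: 0):
    total = 0
    for path in weight_paths:
        if repo_type == "local":
            # 1. Local repos: the file must exist directly on disk.
            full = os.path.join(local_root, path)
            if not os.path.exists(full):
                raise FileNotFoundError(full)
            total += os.path.getsize(full)
        elif repo_type == "online":
            # 2. Online repos: try the local Hugging Face cache first...
            cached = check_hf_cache(path)
            if cached is not None:
                total += os.path.getsize(cached)
            else:
                # ...and only fall back to a Hub API call if uncached.
                total += query_hub_api(path)
        else:
            raise RuntimeError(f"unexpected repo_type: {repo_type}")
    return total
```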
## `MAXModelConfigBase` {#max.pipelines.lib.model_config.MAXModelConfigBase}
> class max.pipelines.lib.model\_config.MAXModelConfigBase(\*, config\_file=None, section\_name=None)
Bases: `ConfigFileModel`
Abstract base class for all (required) MAX model configs.
This base class is used to configure a model for a pipeline, and is also
handy for sidestepping the need to pass optional fields when subclassing
MAXModelConfig.
### `model_config` {#max.pipelines.lib.model_config.MAXModelConfigBase.model_config}
> model\_config: ClassVar\[ConfigDict] = {'arbitrary\_types\_allowed': True}
Configuration for the model; should be a dictionary conforming to pydantic's `ConfigDict`.
---
## pipeline
MAX pipeline for model inference and generation (Text Generation variant).
## `BatchInfo` {#max.pipelines.lib.pipeline_variants.text_generation.BatchInfo}
> class max.pipelines.lib.pipeline\_variants.text\_generation.BatchInfo(past\_seq\_lens, seq\_lens, num\_steps)
Information about a batch of requests passed to the pipeline.
### `num_steps` {#max.pipelines.lib.pipeline_variants.text_generation.BatchInfo.num_steps}
> num\_steps: [int](https://docs.python.org/3/library/functions.html#int)
Number of steps to run in the pipeline
### `past_seq_lens` {#max.pipelines.lib.pipeline_variants.text_generation.BatchInfo.past_seq_lens}
> past\_seq\_lens: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]
Coordinated list of past sequence lengths (i.e. context lengths)
### `seq_lens` {#max.pipelines.lib.pipeline_variants.text_generation.BatchInfo.seq_lens}
> seq\_lens: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]
Coordinated list of sequence lengths, i.e. prompt\_len or 1
## `TextGenerationPipeline` {#max.pipelines.lib.pipeline_variants.text_generation.TextGenerationPipeline}
> class max.pipelines.lib.pipeline\_variants.text\_generation.TextGenerationPipeline(pipeline\_config, pipeline\_model, eos\_token\_id, weight\_adapters, tokenizer)
Generalized token generator pipeline.
### `execute()` {#max.pipelines.lib.pipeline_variants.text_generation.TextGenerationPipeline.execute}
> execute(inputs)
Processes the batch and returns decoded tokens.
Given a batch, executes the graph for num\_steps in a multi-step
scenario, then decodes the tokens and returns the list of decoded
tokens.
### `kv_managers` {#max.pipelines.lib.pipeline_variants.text_generation.TextGenerationPipeline.kv_managers}
> property kv\_managers: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]
Return the list of KV cache managers backing this pipeline.
### `pipeline_config` {#max.pipelines.lib.pipeline_variants.text_generation.TextGenerationPipeline.pipeline_config}
> property pipeline\_config: [PipelineConfig](config.md#max.pipelines.lib.config.PipelineConfig)
Return the pipeline configuration.
### `prepare_batch()` {#max.pipelines.lib.pipeline_variants.text_generation.TextGenerationPipeline.prepare_batch}
> prepare\_batch(batches, num\_steps)
Prepare model inputs and ancillary state for multi-step execution.
This flattens replica batches, optionally initializes constrained
decoding bitmasks, ensures KV-cache reservations, clamps `num_steps`
per context, and builds initial model inputs.
**Parameters:**
* batches ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[list](https://docs.python.org/3/library/stdtypes.html#list)\[TextGenerationContextType]]) – Per-replica list of contexts.
* num\_steps ([int](https://docs.python.org/3/library/functions.html#int)) – Desired number of steps to run.
**Returns:**
* ModelInputs: Prepared inputs for the first step.
* int: The clamped number of steps to run.
* Optional\[np.ndarray]: The structured decoding bitmask or None.
* list\[TextGenerationContextType]: The flattened context batch.
**Return type:**
A tuple of `(ModelInputs, int, Optional[np.ndarray], list[TextGenerationContextType])`.
### `release()` {#max.pipelines.lib.pipeline_variants.text_generation.TextGenerationPipeline.release}
> release(request\_id)
Mark the context as complete, releasing the cache slot from the KV manager.
Note: KV cache lifecycle is now managed by the scheduler. This method
is kept for interface compatibility but is a no-op for regular pipelines.
### `tokenizer` {#max.pipelines.lib.pipeline_variants.text_generation.TextGenerationPipeline.tokenizer}
> property tokenizer: [PipelineTokenizer](../interfaces.md#max.interfaces.PipelineTokenizer)\[TextGenerationContextType, [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]], [TextGenerationRequest](../interfaces.md#max.interfaces.TextGenerationRequest)]
Return the tokenizer used for building contexts and decoding.
### `update_for_structured_output()` {#max.pipelines.lib.pipeline_variants.text_generation.TextGenerationPipeline.update_for_structured_output}
> update\_for\_structured\_output(context, bitmask, index)
Update context and logits bitmask for structured output.
If a `json_schema` is present and no matcher is set, this compiles a
grammar matcher and installs it on the context. It may also jump ahead in
generation and fill the per-request token bitmask used to constrain the
next-token distribution.
**Parameters:**
* context (TextGenerationContextType) – Request context to update.
* bitmask ([ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), ...], [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[int32]]) – Optional preallocated bitmask buffer; updated in-place.
* index ([int](https://docs.python.org/3/library/functions.html#int)) – Global position into the bitmask for this request.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If a JSON schema is provided but structured output is not
enabled via sampling configuration.
**Return type:**
None
## `StandaloneSpeculativeDecodingPipeline` {#max.pipelines.lib.speculative_decoding.StandaloneSpeculativeDecodingPipeline}
> final class max.pipelines.lib.speculative\_decoding.StandaloneSpeculativeDecodingPipeline(pipeline\_config, pipeline\_model, eos\_token\_id, weight\_adapters, tokenizer, draft\_pipeline\_model=None, draft\_weight\_adapters=None)
Bases: `SpeculativeDecodingPipelineBase`
Standalone speculative decoding where draft model runs independently.
In this approach, the draft model generates tokens without any information
from the target model, then the target model verifies these tokens.
### `generate_draft_tokens()` {#max.pipelines.lib.speculative_decoding.StandaloneSpeculativeDecodingPipeline.generate_draft_tokens}
> generate\_draft\_tokens(batch, num\_steps, model\_inputs)
Generates draft tokens for the batch using the draft model.
## `EmbeddingsPipeline` {#max.pipelines.lib.embeddings_pipeline.EmbeddingsPipeline}
### `execute()` {#max.pipelines.lib.embeddings_pipeline.EmbeddingsPipeline.execute}
> execute(inputs)
Processes the batch and returns embeddings.
Given a batch, executes the graph and returns the list of embedding
outputs per request.
### `release()` {#max.pipelines.lib.embeddings_pipeline.EmbeddingsPipeline.release}
> release(request\_id)
Releases resources for the request (no-op for embeddings).
---
## registry (Pipelines)
Model registry, for tracking various model variants.
## `PipelineRegistry` {#max.pipelines.lib.registry.PipelineRegistry}
> class max.pipelines.lib.registry.PipelineRegistry(architectures)
Registry for managing supported model architectures and their pipelines.
This class maintains a collection of [`SupportedArchitecture`](#max.pipelines.lib.registry.SupportedArchitecture)
instances, each defining how a particular model architecture should be
loaded, configured, and executed.
:::note Note
Do not instantiate this class directly. Always use the global
[`PIPELINE_REGISTRY`](#max.pipelines.lib.registry.PIPELINE_REGISTRY) singleton, which is automatically populated
with all built-in architectures when you import `max.pipelines`.
:::
Use [`PIPELINE_REGISTRY`](#max.pipelines.lib.registry.PIPELINE_REGISTRY) when you want to:
* **Register a custom architecture**: Call [`register()`](#max.pipelines.lib.registry.PipelineRegistry.register) to add a new
MAX model architecture to the registry before loading it.
* **Query supported models**: Call [`retrieve_architecture()`](#max.pipelines.lib.registry.PipelineRegistry.retrieve_architecture) to check
if a Hugging Face model repository is supported before attempting to load it.
* **Access cached configs**: Methods like [`get_active_huggingface_config()`](#max.pipelines.lib.registry.PipelineRegistry.get_active_huggingface_config) and
[`get_active_tokenizer()`](#max.pipelines.lib.registry.PipelineRegistry.get_active_tokenizer) provide cached access to model configurations and tokenizers.
### `get_active_diffusers_config()` {#max.pipelines.lib.registry.PipelineRegistry.get_active_diffusers_config}
> get\_active\_diffusers\_config(huggingface\_repo)
Retrieves or creates a cached diffusers config for the given repository.
This method checks if the repository is a diffusion pipeline by looking for
model\_index.json. If found, it downloads and caches the config. If not found,
returns None.
**Parameters:**
huggingface\_repo ([HuggingFaceRepo](hf_utils.md#max.pipelines.lib.hf_utils.HuggingFaceRepo)) – The HuggingFaceRepo containing the model.
**Returns:**
The diffusers config dict if this is a diffusion pipeline, None otherwise.
### `get_active_huggingface_config()` {#max.pipelines.lib.registry.PipelineRegistry.get_active_huggingface_config}
> get\_active\_huggingface\_config(huggingface\_repo)
Retrieves or creates a cached Hugging Face AutoConfig for the given model.
Maintains a cache of Hugging Face configurations to avoid reloading them
unnecessarily, which would incur a Hugging Face Hub API call.
If a config for the given model hasn’t been loaded before, it will
create a new one using AutoConfig.from\_pretrained() with the model’s
settings.
Note: The cache key (HuggingFaceRepo) includes trust\_remote\_code in its
hash, so configs with different trust settings are cached separately.
For multiprocessing, each worker process has its own registry instance
with an empty cache, so configs are loaded fresh in each worker.
**Parameters:**
huggingface\_repo ([HuggingFaceRepo](hf_utils.md#max.pipelines.lib.hf_utils.HuggingFaceRepo)) – The HuggingFaceRepo containing the model.
**Returns:**
The Hugging Face configuration object for the model.
**Return type:**
AutoConfig
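The caching behavior, including the trust setting in the cache key, can be sketched as follows. `ConfigCache` is a hypothetical simplification, not the registry's actual implementation:

```python
# Hypothetical sketch of a config cache keyed by (repo_id, trust_remote_code),
# mirroring the caching behavior described above (not the real registry code).
class ConfigCache:
    def __init__(self, loader):
        self._cache = {}
        self._loader = loader  # stand-in for AutoConfig.from_pretrained
        self.loads = 0

    def get(self, repo_id, trust_remote_code=False):
        # The trust setting is part of the key, so configs loaded with
        # different trust settings are cached separately.
        key = (repo_id, trust_remote_code)
        if key not in self._cache:
            self.loads += 1
            self._cache[key] = self._loader(repo_id, trust_remote_code)
        return self._cache[key]
```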
### `get_active_tokenizer()` {#max.pipelines.lib.registry.PipelineRegistry.get_active_tokenizer}
> get\_active\_tokenizer(huggingface\_repo)
Retrieves or creates a cached Hugging Face AutoTokenizer for the given model.
Maintains a cache of Hugging Face tokenizers to avoid reloading them
unnecessarily, which would incur a Hugging Face Hub API call.
If a tokenizer for the given model hasn’t been loaded before, it will
create a new one using AutoTokenizer.from\_pretrained() with the model’s
settings.
**Parameters:**
huggingface\_repo ([HuggingFaceRepo](hf_utils.md#max.pipelines.lib.hf_utils.HuggingFaceRepo)) – The HuggingFaceRepo containing the model.
**Returns:**
The Hugging Face tokenizer for the model.
**Return type:**
PreTrainedTokenizer | PreTrainedTokenizerFast
### `register()` {#max.pipelines.lib.registry.PipelineRegistry.register}
> register(architecture, \*, allow\_override=False)
Adds a new architecture to the registry.
If multiple architectures share the same name but have different tasks,
they are registered in a secondary lookup table keyed by (name, task).
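The two-level lookup can be sketched as follows. `MiniRegistry` is a toy illustration of the behavior described above, not the real `PipelineRegistry`:

```python
# Toy sketch of the registry's two-level lookup: architectures that share a
# name but differ in task go into a secondary table keyed by (name, task).
# MiniRegistry is an illustrative simplification, not the real PipelineRegistry.
class MiniRegistry:
    def __init__(self):
        self._by_name = {}       # name -> (task, arch)
        self._by_name_task = {}  # (name, task) -> arch

    def register(self, name, task, arch, allow_override=False):
        existing = self._by_name.get(name)
        if existing is not None and existing[0] != task:
            # Same name, different task: record both under (name, task).
            self._by_name_task[(name, existing[0])] = existing[1]
            self._by_name_task[(name, task)] = arch
            return
        if existing is not None and not allow_override:
            raise ValueError(f"{name} already registered")
        self._by_name[name] = (task, arch)

    def retrieve(self, name, task=None):
        if task is not None and (name, task) in self._by_name_task:
            return self._by_name_task[(name, task)]
        entry = self._by_name.get(name)
        return entry[1] if entry else None
```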
### `reset()` {#max.pipelines.lib.registry.PipelineRegistry.reset}
> reset()
Clears all registered architectures (mainly for tests).
**Return type:**
None
### `retrieve()` {#max.pipelines.lib.registry.PipelineRegistry.retrieve}
> retrieve(pipeline\_config, task=PipelineTask.TEXT\_GENERATION, override\_architecture=None)
Retrieves the tokenizer and an instantiated pipeline for the config.
### `retrieve_architecture()` {#max.pipelines.lib.registry.PipelineRegistry.retrieve_architecture}
> retrieve\_architecture(huggingface\_repo, use\_legacy\_module=True, task=None)
Retrieve architecture matching the Hugging Face model config.
**Parameters:**
* huggingface\_repo ([HuggingFaceRepo](hf_utils.md#max.pipelines.lib.hf_utils.HuggingFaceRepo)) – The Hugging Face repository to match against.
* use\_legacy\_module ([bool](https://docs.python.org/3/library/functions.html#bool)) – Whether to use legacy Module architecture (default=True).
When True, appends “\_Legacy” suffix to find legacy graph-based architecture.
When False, uses the standard Hugging Face architecture name for new API.
* task ([PipelineTask](../interfaces.md#max.interfaces.PipelineTask) | None) – Optional task to disambiguate when multiple architectures share the same name.
If not provided and multiple architectures share the same name, the task will
be inferred from the Hugging Face Hub’s pipeline\_tag.
**Returns:**
The matching SupportedArchitecture or None if no match found.
### `retrieve_context_type()` {#max.pipelines.lib.registry.PipelineRegistry.retrieve_context_type}
> retrieve\_context\_type(pipeline\_config, override\_architecture=None, task=None)
Retrieve the context class type associated with the architecture for the given pipeline configuration.
The context type defines how the pipeline manages request state and inputs during
model execution. Different architectures may use different context implementations
that adhere to either the TextGenerationContext or EmbeddingsContext protocol.
**Parameters:**
* pipeline\_config ([PipelineConfig](config.md#max.pipelines.lib.config.PipelineConfig)) – The configuration for the pipeline.
* override\_architecture ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) – Optional architecture name to use instead of looking up
based on the model repository. This is useful for cases like audio generation
where the pipeline uses a different architecture (e.g., audio decoder) than
the underlying model repository.
* task ([PipelineTask](../interfaces.md#max.interfaces.PipelineTask) | None) – Optional pipeline task to disambiguate when multiple architectures share
the same name but serve different tasks.
**Returns:**
The context class type associated with the architecture, which implements
either the TextGenerationContext or EmbeddingsContext protocol.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If no supported architecture is found for the given model repository
or override architecture name.
### `retrieve_factory()` {#max.pipelines.lib.registry.PipelineRegistry.retrieve_factory}
> retrieve\_factory(pipeline\_config, task=PipelineTask.TEXT\_GENERATION, override\_architecture=None)
Retrieves the tokenizer and a factory that creates the pipeline instance.
### `retrieve_pipeline_task()` {#max.pipelines.lib.registry.PipelineRegistry.retrieve_pipeline_task}
> retrieve\_pipeline\_task(pipeline\_config)
Retrieves the pipeline task for the given pipeline configuration.
**Parameters:**
pipeline\_config ([PipelineConfig](config.md#max.pipelines.lib.config.PipelineConfig)) – The configuration for the pipeline.
**Returns:**
The task associated with the architecture.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If no supported architecture is found for the given model repository.
### `retrieve_tokenizer()` {#max.pipelines.lib.registry.PipelineRegistry.retrieve_tokenizer}
> retrieve\_tokenizer(pipeline\_config, override\_architecture=None, task=None)
Retrieves a tokenizer for the given pipeline configuration.
**Parameters:**
* pipeline\_config ([PipelineConfig](config.md#max.pipelines.lib.config.PipelineConfig)) – Configuration for the pipeline
* override\_architecture ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) – Optional architecture override string
* task ([PipelineTask](../interfaces.md#max.interfaces.PipelineTask) | None) – Optional pipeline task to disambiguate when multiple
architectures share the same name but serve different tasks.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If no architecture is found.
## `SupportedArchitecture` {#max.pipelines.lib.registry.SupportedArchitecture}
> class max.pipelines.lib.registry.SupportedArchitecture(name, example\_repo\_ids, default\_encoding, supported\_encodings, pipeline\_model, task, tokenizer, default\_weights\_format, context\_type, config, rope\_type=RopeType.none, weight\_adapters=\<factory>, multi\_gpu\_supported=False, required\_arguments=\<factory>, context\_validators=\<factory>, supports\_empty\_batches=False, requires\_max\_batch\_context\_length=False)
Represents a model architecture configuration for MAX pipelines.
Defines the components and settings required to
support a specific model architecture within the MAX pipeline system.
Each SupportedArchitecture instance encapsulates the model implementation,
tokenizer, supported encodings, and other architecture-specific configuration.
New architectures should be registered into the [`PipelineRegistry`](#max.pipelines.lib.registry.PipelineRegistry)
using the [`register()`](#max.pipelines.lib.registry.PipelineRegistry.register) method.
**Example:**
```python
my_architecture = SupportedArchitecture(
name="MyModelForCausalLM", # Must match your Hugging Face model class name
example_repo_ids=[
"your-org/your-model-name", # Add example model repository IDs
],
default_encoding=SupportedEncoding.q4_k,
supported_encodings={
SupportedEncoding.q4_k: [KVCacheStrategy.PAGED],
SupportedEncoding.bfloat16: [KVCacheStrategy.PAGED],
# Add other encodings your model supports
},
pipeline_model=MyModel,
tokenizer=TextTokenizer,
context_type=TextContext,
config=MyModelConfig, # Architecture-specific config class
default_weights_format=WeightsFormat.safetensors,
rope_type=RopeType.none,
weight_adapters={
WeightsFormat.safetensors: weight_adapters.convert_safetensor_state_dict,
# Add other weight formats if needed
},
multi_gpu_supported=True, # Set based on your implementation capabilities
required_arguments={"some_arg": True},
task=PipelineTask.TEXT_GENERATION,
)
```
### `config` {#max.pipelines.lib.registry.SupportedArchitecture.config}
> config: [type](https://docs.python.org/3/library/functions.html#type)\[[ArchConfig](interfaces.md#max.pipelines.lib.interfaces.ArchConfig)]
The architecture-specific configuration class for the model.
This class must implement the `ArchConfig` protocol, providing an
`initialize` method that creates a configuration instance from a
`PipelineConfig`. For models with KV cache, this should be a class
implementing `ArchConfigWithKVCache` to enable KV cache memory estimation.
### `context_type` {#max.pipelines.lib.registry.SupportedArchitecture.context_type}
> context\_type: [type](https://docs.python.org/3/library/functions.html#type)\[[TextGenerationContext](../interfaces.md#max.interfaces.TextGenerationContext)] | [type](https://docs.python.org/3/library/functions.html#type)\[[EmbeddingsContext](../interfaces.md#max.interfaces.EmbeddingsContext)]
The context class type that this architecture uses for managing request state and inputs.
This should be a class (not an instance) that implements either the TextGenerationContext
or EmbeddingsContext protocol, defining how the pipeline processes and tracks requests.
### `context_validators` {#max.pipelines.lib.registry.SupportedArchitecture.context_validators}
> context\_validators: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[Callable](../graph/ops.md#max.graph.ops.Callable)\[\[[TextContext](core.md#max.pipelines.core.TextContext) | [TextAndVisionContext](core.md#max.pipelines.core.TextAndVisionContext)], [None](https://docs.python.org/3/library/constants.html#None)]]
A list of callable validators that verify context inputs before model execution.
These validators are called during context creation to ensure inputs meet
model-specific requirements. Validators should raise InputError for invalid
inputs, providing early error detection before expensive model operations.
```python
def validate_single_image(context: TextContext | TextAndVisionContext) -> None:
if isinstance(context, TextAndVisionContext):
if context.pixel_values and len(context.pixel_values) > 1:
raise InputError(f"Model supports only 1 image, got {len(context.pixel_values)}")
my_architecture = SupportedArchitecture(
# ... other fields ...
context_validators=[validate_single_image],
)
```
### `default_encoding` {#max.pipelines.lib.registry.SupportedArchitecture.default_encoding}
> default\_encoding: SupportedEncoding
The default quantization encoding to use when no specific encoding is requested.
### `default_weights_format` {#max.pipelines.lib.registry.SupportedArchitecture.default_weights_format}
> default\_weights\_format: [WeightsFormat](../graph/weights.md#max.graph.weights.WeightsFormat)
The weights format expected by the pipeline\_model.
### `example_repo_ids` {#max.pipelines.lib.registry.SupportedArchitecture.example_repo_ids}
> example\_repo\_ids: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[str](https://docs.python.org/3/library/stdtypes.html#str)]
A list of Hugging Face repository IDs that use this architecture for testing and validation purposes.
### `multi_gpu_supported` {#max.pipelines.lib.registry.SupportedArchitecture.multi_gpu_supported}
> multi\_gpu\_supported: [bool](https://docs.python.org/3/library/functions.html#bool) = False
Whether the architecture supports multi-GPU execution.
### `name` {#max.pipelines.lib.registry.SupportedArchitecture.name}
> name: [str](https://docs.python.org/3/library/stdtypes.html#str)
The name of the model architecture that must match the Hugging Face model class name.
### `pipeline_model` {#max.pipelines.lib.registry.SupportedArchitecture.pipeline_model}
> pipeline\_model: [type](https://docs.python.org/3/library/functions.html#type)\[[PipelineModel](interfaces.md#max.pipelines.lib.interfaces.PipelineModel)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]
The PipelineModel class that defines the model graph structure and execution logic.
### `required_arguments` {#max.pipelines.lib.registry.SupportedArchitecture.required_arguments}
> required\_arguments: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [bool](https://docs.python.org/3/library/functions.html#bool) | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float)]
A dictionary specifying required values for PipelineConfig options.
### `requires_max_batch_context_length` {#max.pipelines.lib.registry.SupportedArchitecture.requires_max_batch_context_length}
> requires\_max\_batch\_context\_length: [bool](https://docs.python.org/3/library/functions.html#bool) = False
Whether the architecture requires a max batch context length to be specified.
If True and max\_batch\_context\_length is not specified, we will default to
the max sequence length of the model.
### `rope_type` {#max.pipelines.lib.registry.SupportedArchitecture.rope_type}
> rope\_type: RopeType = 'none'
The type of RoPE (Rotary Position Embedding) used by the model.
### `supported_encodings` {#max.pipelines.lib.registry.SupportedArchitecture.supported_encodings}
> supported\_encodings: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[SupportedEncoding, [list](https://docs.python.org/3/library/stdtypes.html#list)\[[KVCacheStrategy](../nn/legacy/kv_cache/cache_params.md#max.nn.legacy.kv_cache.cache_params.KVCacheStrategy)]]
A dictionary mapping supported quantization encodings to their compatible KV cache strategies.
### `supports_empty_batches` {#max.pipelines.lib.registry.SupportedArchitecture.supports_empty_batches}
> supports\_empty\_batches: [bool](https://docs.python.org/3/library/functions.html#bool) = False
Whether the architecture can handle empty batches during inference.
When set to True, the pipeline can process requests with zero-sized batches
without errors. This is useful for certain execution modes and expert parallelism.
Most architectures do not require empty batch support and should leave this as False.
### `task` {#max.pipelines.lib.registry.SupportedArchitecture.task}
> task: [PipelineTask](../interfaces.md#max.interfaces.PipelineTask)
The pipeline task type that this architecture supports.
### `tokenizer` {#max.pipelines.lib.registry.SupportedArchitecture.tokenizer}
> tokenizer: [Callable](../graph/ops.md#max.graph.ops.Callable)\[\[...], [PipelineTokenizer](../interfaces.md#max.interfaces.PipelineTokenizer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), [Any](https://docs.python.org/3/library/typing.html#typing.Any), [Any](https://docs.python.org/3/library/typing.html#typing.Any)]]
A callable that returns a PipelineTokenizer instance for preprocessing model inputs.
### `tokenizer_cls` {#max.pipelines.lib.registry.SupportedArchitecture.tokenizer_cls}
> property tokenizer\_cls: [type](https://docs.python.org/3/library/functions.html#type)\[[PipelineTokenizer](../interfaces.md#max.interfaces.PipelineTokenizer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), [Any](https://docs.python.org/3/library/typing.html#typing.Any), [Any](https://docs.python.org/3/library/typing.html#typing.Any)]]
Returns the tokenizer class for this architecture.
### `weight_adapters` {#max.pipelines.lib.registry.SupportedArchitecture.weight_adapters}
> weight\_adapters: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[WeightsFormat](../graph/weights.md#max.graph.weights.WeightsFormat), [Callable](../graph/ops.md#max.graph.ops.Callable)\[\[...], [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [WeightData](../graph/weights.md#max.graph.weights.WeightData)]]]
A dictionary of weight format adapters for converting checkpoints from different formats to the default format.
## `get_pipeline_for_task()` {#max.pipelines.lib.registry.get_pipeline_for_task}
> max.pipelines.lib.registry.get\_pipeline\_for\_task(task, pipeline\_config)
## `PIPELINE_REGISTRY` {#max.pipelines.lib.registry.PIPELINE_REGISTRY}
> max.pipelines.lib.registry.PIPELINE\_REGISTRY: [PipelineRegistry](#max.pipelines.lib.registry.PipelineRegistry)
Global registry of supported model architectures.
This is the singleton [`PipelineRegistry`](#max.pipelines.lib.registry.PipelineRegistry) instance you can use to
register new MAX model architectures and query supported models.
---
## sampling (Pipelines)
## `rejection_sampler()` {#max.pipelines.lib.sampling.sampling.rejection_sampler}
> max.pipelines.lib.sampling.sampling.rejection\_sampler(device, \*, seed=0)
## `rejection_sampler_with_residuals()` {#max.pipelines.lib.sampling.sampling.rejection_sampler_with_residuals}
> max.pipelines.lib.sampling.sampling.rejection\_sampler\_with\_residuals(device, \*, seed=0, debug=False)
Builds a rejection sampler with residual sampling for speculative decoding.
Computes acceptance ratios for draft tokens, finds first rejection,
samples from residual distribution (target - draft), and generates bonus
tokens.
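The acceptance-and-residual scheme described above can be illustrated with a single-sequence NumPy toy. This is not the compiled MAX sampler; `rejection_sample` is a hypothetical name, and `target_probs` carries one extra row for the bonus-token position:

```python
# Toy single-sequence rejection sampler with residual sampling, illustrating
# the algorithm described above (not the real MAX implementation).
import numpy as np

def rejection_sample(draft_probs, target_probs, draft_tokens, rng):
    """draft_probs: [num_draft, vocab]; target_probs: [num_draft + 1, vocab]
    (the extra row is the bonus-token distribution)."""
    num_draft = len(draft_tokens)
    out = []
    for i in range(num_draft):
        tok = draft_tokens[i]
        # Accept draft token with probability min(1, p_target / p_draft).
        ratio = target_probs[i, tok] / max(draft_probs[i, tok], 1e-20)
        if rng.random() < min(1.0, ratio):
            out.append(int(tok))
        else:
            # First rejection: sample from the residual max(target - draft, 0).
            residual = np.maximum(target_probs[i] - draft_probs[i], 0.0)
            residual /= residual.sum()
            out.append(int(rng.choice(residual.size, p=residual)))
            return out  # stop at the first rejection
    # All drafts accepted: draw a bonus token from the target's extra row.
    out.append(int(rng.choice(target_probs.shape[1], p=target_probs[num_draft])))
    return out
```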
## `PreTrainedPipelineTokenizer` {#max.pipelines.lib.tokenizer.PreTrainedPipelineTokenizer}
### `apply_chat_template()` {#max.pipelines.lib.tokenizer.PreTrainedPipelineTokenizer.apply_chat_template}
> apply\_chat\_template(messages)
Applies the delegate’s chat template to the messages.
### `decode()` {#max.pipelines.lib.tokenizer.PreTrainedPipelineTokenizer.decode}
> async decode(encoded, \*\*kwargs)
Decodes token ids to text via the delegate.
### `encode()` {#max.pipelines.lib.tokenizer.PreTrainedPipelineTokenizer.encode}
> async encode(prompt, add\_special\_tokens=False)
Encodes the prompt to token ids via the delegate.
## `TextAndVisionTokenizer` {#max.pipelines.lib.tokenizer.TextAndVisionTokenizer}
### `apply_chat_template()` {#max.pipelines.lib.tokenizer.TextAndVisionTokenizer.apply_chat_template}
> apply\_chat\_template(messages)
Applies the processor’s chat template to the messages.
### `encode()` {#max.pipelines.lib.tokenizer.TextAndVisionTokenizer.encode}
> async encode(prompt, add\_special\_tokens=True)
Transforms the provided prompt into a token array.
### `eos` {#max.pipelines.lib.tokenizer.TextAndVisionTokenizer.eos}
> property eos: [int](https://docs.python.org/3/library/functions.html#int)
Returns the end-of-sequence token ID from the delegate.
### `expects_content_wrapping` {#max.pipelines.lib.tokenizer.TextAndVisionTokenizer.expects_content_wrapping}
> property expects\_content\_wrapping: [bool](https://docs.python.org/3/library/functions.html#bool)
Returns whether this tokenizer expects content wrapping.
### `new_context()` {#max.pipelines.lib.tokenizer.TextAndVisionTokenizer.new_context}
> async new\_context(request)
Creates a new TextAndVisionContext object using the necessary information from the TextGenerationRequest.
## `TextTokenizer` {#max.pipelines.lib.tokenizer.TextTokenizer}
> class max.pipelines.lib.tokenizer.TextTokenizer(model\_path, pipeline\_config, \*, revision=None, max\_length=None, trust\_remote\_code=False, enable\_llama\_whitespace\_fix=False, chat\_template=None, context\_validators=None, \*\*unused\_kwargs)
Encapsulates creation of TextContext and specific token encode/decode logic.
**Parameters:**
* model\_path ([str](https://docs.python.org/3/library/stdtypes.html#str)) – Path to the model/tokenizer
* revision ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) – Git revision/branch to use
* max\_length ([int](https://docs.python.org/3/library/functions.html#int) | None) – Maximum sequence length
* trust\_remote\_code ([bool](https://docs.python.org/3/library/functions.html#bool)) – Whether to trust remote code from the model
* enable\_llama\_whitespace\_fix ([bool](https://docs.python.org/3/library/functions.html#bool)) – Enable whitespace fix for Llama tokenizers
* pipeline\_config ([PipelineConfig](config.md#max.pipelines.lib.config.PipelineConfig)) – Optional pipeline configuration
* chat\_template ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) – Optional custom chat template string to override the one
shipped with the Hugging Face model config. This allows
customizing the prompt formatting for different use cases.
* context\_validators ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[Callable](../graph/ops.md#max.graph.ops.Callable)\[\[[TextContext](core.md#max.pipelines.core.TextContext)], None]] | None)
### `apply_chat_template()` {#max.pipelines.lib.tokenizer.TextTokenizer.apply_chat_template}
> apply\_chat\_template(messages, tools, chat\_template\_options=None)
Applies the delegate chat template to messages (and optional tools).
### `encode()` {#max.pipelines.lib.tokenizer.TextTokenizer.encode}
> async encode(prompt, add\_special\_tokens=True)
Transforms the provided prompt into a token array.
### `eos` {#max.pipelines.lib.tokenizer.TextTokenizer.eos}
> property eos: [int](https://docs.python.org/3/library/functions.html#int)
Returns the end-of-sequence token ID from the delegate.
### `expects_content_wrapping` {#max.pipelines.lib.tokenizer.TextTokenizer.expects_content_wrapping}
> property expects\_content\_wrapping: [bool](https://docs.python.org/3/library/functions.html#bool)
Returns whether this tokenizer expects content wrapping.
### `new_context()` {#max.pipelines.lib.tokenizer.TextTokenizer.new_context}
> async new\_context(request)
Creates a new TextContext object using the necessary information from the TextGenerationRequest.
## `max_tokens_to_generate()` {#max.pipelines.lib.tokenizer.max_tokens_to_generate}
> max.pipelines.lib.tokenizer.max\_tokens\_to\_generate(prompt\_size, max\_length, max\_new\_tokens=None)
Returns the max number of new tokens to generate.
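One plausible reading of this helper's contract is "cap the requested new tokens by the room left in the context window." The sketch below is an assumption about the logic, not the actual implementation:

```python
def max_tokens_to_generate(prompt_size, max_length, max_new_tokens=None):
    """Cap the number of new tokens by the room left in the context window."""
    if max_length is None:
        return max_new_tokens
    # Room remaining after the prompt occupies part of the window.
    room = max(max_length - prompt_size, 0)
    if max_new_tokens is None:
        return room
    return min(max_new_tokens, room)
```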
---
## profiler
Performance profiling and tracing utilities for MAX.
This module provides tools for profiling and tracing MAX operations to analyze
performance characteristics. Profiling captures timing information for code
execution, which helps identify bottlenecks and optimize your models.
To enable profiling, set the `MODULAR_ENABLE_PROFILING=1` environment
variable before running your code. Without this variable, profiling calls will
be no-ops with minimal overhead.
The profiler supports three usage patterns:
1. **Context manager**: Use [`Tracer`](#max.profiler.Tracer) as a context manager to profile a
code block.
2. **Decorator**: Use [`@traced`](#max.profiler.traced) to profile entire functions.
3. **Manual stack**: Use [`Tracer`](#max.profiler.Tracer) methods to explicitly control profiling
spans.
## `Tracer` {#max.profiler.Tracer}
> class max.profiler.Tracer(message=None, color='modular\_purple')
A stack-based profiling manager for creating nested profiling spans.
Manages a stack of profiling spans that allows for nested tracing without
requiring deeply nested `with Tracer(name):` statements. This is especially
useful when you need to dynamically create and manage profiling spans based
on runtime conditions or when profiling spans don’t align with your code’s
block structure.
The `Tracer` can be used both as a context manager and as a manual stack
manager. As a context manager, it ensures all pushed spans are properly
closed when the context exits.
```python
from max.profiler import Tracer

# Manual stack management
tracer = Tracer("parent_operation", color="modular_purple")
tracer.push("child_operation")
# ... perform work ...
tracer.pop()

# Context manager with manual stack
with Tracer("parent_operation", color="modular_purple") as tracer:
    # The parent span is named "parent_operation"
    tracer.push("child_operation")
    # ... perform work ...
    tracer.pop()
# All spans are automatically closed on context exit
```
**Parameters:**
* message ([str](https://docs.python.org/3/library/stdtypes.html#str) | None)
* color ([str](https://docs.python.org/3/library/stdtypes.html#str))
### `cleanup()` {#max.profiler.Tracer.cleanup}
> cleanup()
Closes all remaining profiling spans.
Pops and closes all profiling spans that were pushed onto the stack.
This method is automatically called when the tracer is used as a
context manager or when the object is deleted.
**Return type:**
None
### `mark()` {#max.profiler.Tracer.mark}
> mark()
Marks the current profiling span with a timestamp.
Records a timestamp event within the current profiling span. This is
useful for marking significant events or milestones within a longer
operation.
**Raises:**
[AssertionError](https://docs.python.org/3/library/exceptions.html#AssertionError) – If the stack is empty when mark is called.
**Return type:**
None
### `next()` {#max.profiler.Tracer.next}
> next(message, color='modular\_purple')
Transitions to the next profiling span.
Pops the current profiling span and immediately pushes a new one with
the specified message. This is a convenience method for sequential
operations at the same nesting level.
**Parameters:**
* message ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The name of the new profiling span.
* color ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The color of the profiling span for visualization tools.
**Return type:**
None
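The push/next/pop semantics amount to a simple span stack, where `next()` is a pop immediately followed by a push. The class below is an illustrative pure-Python model of that behavior, not the MAX profiler internals:

```python
class ToyTracer:
    """Minimal model of Tracer's span stack: next() == pop() then push()."""
    def __init__(self):
        self.stack = []   # currently open spans
        self.closed = []  # spans closed, in completion order

    def push(self, message):
        self.stack.append(message)

    def pop(self):
        self.closed.append(self.stack.pop())

    def next(self, message):
        self.pop()
        self.push(message)

    def cleanup(self):
        while self.stack:
            self.pop()

t = ToyTracer()
t.push("load")
t.next("preprocess")  # closes "load", opens "preprocess"
t.next("inference")   # closes "preprocess", opens "inference"
t.cleanup()           # closes everything still open
```

This is why `next()` is convenient for sequential pipeline phases: each call ends the previous phase's span and starts the next one at the same nesting level.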
### `pop()` {#max.profiler.Tracer.pop}
> pop(exc\_type=None, exc\_value=None, traceback=None)
Pops a profiling span off the stack and closes it.
Removes the most recently pushed profiling span from the stack and
closes it, recording its execution time. Exception information can be
passed through for proper error handling in context managers.
**Parameters:**
* exc\_type ([type](https://docs.python.org/3/library/functions.html#type)\[[BaseException](https://docs.python.org/3/library/exceptions.html#BaseException)] | None) – The exception type if an exception occurred, or None.
* exc\_value ([BaseException](https://docs.python.org/3/library/exceptions.html#BaseException) | None) – The exception instance if an exception occurred, or None.
* traceback ([TracebackType](https://docs.python.org/3/library/types.html#types.TracebackType) | None) – The traceback object if an exception occurred, or None.
**Return type:**
None
### `push()` {#max.profiler.Tracer.push}
> push(message=None, color='modular\_purple')
Pushes a new profiling span onto the stack.
Creates and activates a new profiling span. If profiling is disabled or
no message is provided, pushes a None placeholder to maintain stack
consistency.
**Parameters:**
* message ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) – The name of the profiling span. If None, no span is created.
* color ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The color of the profiling span for visualization tools.
**Return type:**
None
## `traced()` {#max.profiler.traced}
> max.profiler.traced(func=None, \*, message=None, color='modular\_purple')
Decorator for creating a profiling span for a function.
Creates a profiling span that measures the execution time of the decorated
function. This is useful for identifying performance bottlenecks without
modifying the function’s internal code. The decorator supports both
synchronous and asynchronous functions.
```python
from max.profiler import traced

# Decorator with custom span name
@traced(message="inference", color="red")
def run_model() -> None:
    # The profiling span is named "inference"
    model.execute()

# Decorator with default span name (uses function name)
@traced
def preprocess_data() -> None:
    # The profiling span is named "preprocess_data"
    data.normalize()
```
**Parameters:**
* func (\_FuncType | None) – The function to profile.
* message ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) – The name of the profiling span. If None, uses the function name.
* color ([str](https://docs.python.org/3/library/stdtypes.html#str)) – The color of the profiling span for visualization tools.
**Returns:**
The decorated function wrapped in a trace object.
**Return type:**
[Callable](graph/ops.md#max.graph.ops.Callable)
---
## random
Provides random tensor generation utilities.
This module provides functions for generating random tensors with various
distributions. All functions support specifying data type and device,
with sensible defaults based on the target device.
You can generate random tensors using different distributions:
```default
from max import random
from max.dtype import DType
from max.driver import CPU
tensor1 = random.uniform((2, 3), dtype=DType.float32, device=CPU())
tensor2 = random.uniform((4, 4), range=(0, 1), dtype=DType.float32, device=CPU())
```
## `gaussian()` {#max.random.gaussian}
> max.random.gaussian(shape=(), mean=0.0, std=1.0, \*, dtype=None, device=None)
Creates a tensor filled with random values from a Gaussian (normal) distribution.
Generates a tensor with values sampled from a normal (Gaussian) distribution
with the specified mean and standard deviation. This is commonly used for
weight initialization using techniques like Xavier/Glorot or He initialization.
Create tensors with random values from a Gaussian distribution:
```default
from max import random
from max.driver import CPU
from max.dtype import DType
# Standard normal distribution
tensor = random.gaussian((2, 3), dtype=DType.float32, device=CPU())
```
**Parameters:**
* shape ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The shape of the output tensor. Defaults to scalar (empty tuple).
* mean ([float](https://docs.python.org/3/library/functions.html#float)) – The mean (center) of the Gaussian distribution. This determines
where the distribution is centered. Defaults to `0.0`.
* std ([float](https://docs.python.org/3/library/functions.html#float)) – The standard deviation (spread) of the Gaussian distribution.
Must be positive. Larger values create more spread in the distribution.
Defaults to `1.0`.
* dtype ([DType](dtype.md#max.dtype.DType) | None) – The data type of the output tensor. If `None`, uses the
default dtype for the specified device (float32 for CPU,
bfloat16 for accelerators). Defaults to `None`.
* device ([Device](driver.md#max.driver.Device) | None) – The device where the tensor will be allocated. If `None`,
uses the default device (accelerator if available, otherwise CPU).
Defaults to `None`.
**Returns:**
A [`Tensor`](tensor.md#max.tensor.Tensor) with random values sampled from
the Gaussian distribution.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If std <= 0.
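For reference, the Xavier/Glorot and He schemes mentioned above choose `std` from the layer's fan-in and fan-out. The formulas below come from the standard initialization literature and are not part of this API; the result would be passed as the `std=` argument to `gaussian()`:

```python
import math

def xavier_std(fan_in, fan_out):
    # Glorot & Bengio (2010): variance 2 / (fan_in + fan_out)
    return math.sqrt(2.0 / (fan_in + fan_out))

def he_std(fan_in):
    # He et al. (2015), suited to ReLU networks: variance 2 / fan_in
    return math.sqrt(2.0 / fan_in)
```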
## `normal()` {#max.random.normal}
> max.random.normal(shape=(), mean=0.0, std=1.0, \*, dtype=None, device=None)
Alias for [`gaussian()`](#max.random.gaussian).
Creates a tensor with values from a normal (Gaussian) distribution.
## `seed()` {#max.random.seed}
> max.random.seed()
Gets the global random seed tensor.
Returns the global seed tensor used for random number generation in eager
execution mode. Creates the seed tensor on first access, initialized with
the dtype, shape, and device specified by `ops.random.SeedType`.
**Returns:**
The global seed tensor for random number generation.
**Return type:**
[Tensor](tensor.md#max.tensor.Tensor)
## `set_seed()` {#max.random.set_seed}
> max.random.set\_seed(value)
Sets the global random seed value.
Updates the global random seed to the specified value. This affects all
subsequent random number generation in eager execution mode.
**Parameters:**
value ([int](https://docs.python.org/3/library/functions.html#int)) – The integer seed value to set.
**Return type:**
None
## `uniform()` {#max.random.uniform}
> max.random.uniform(shape=(), range=(0, 1), \*, dtype=None, device=None)
Creates a tensor filled with random values from a uniform distribution.
Generates a tensor with values uniformly distributed between the specified
minimum and maximum bounds. This is useful for initializing weights,
generating random inputs, or creating noise.
Create tensors with uniform random values:
```default
from max import random
from max.dtype import DType
from max.driver import CPU
# Generate 2x3 tensor with values between 0 and 1
tensor1 = random.uniform((2, 3), dtype=DType.float32, device=CPU())
tensor2 = random.uniform((4, 4), range=(0, 1), dtype=DType.float32, device=CPU())
```
**Parameters:**
* shape ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The shape of the output tensor. Defaults to scalar (empty tuple).
* range ([tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[float](https://docs.python.org/3/library/functions.html#float), [float](https://docs.python.org/3/library/functions.html#float)]) – A tuple specifying the (min, max) bounds of the uniform
distribution. The minimum value is inclusive, the maximum value
is exclusive. Defaults to `(0, 1)`.
* dtype ([DType](dtype.md#max.dtype.DType) | None) – The data type of the output tensor. If `None`, uses the
default dtype for the specified device (float32 for CPU,
bfloat16 for accelerators). Defaults to `None`.
* device ([Device](driver.md#max.driver.Device) | None) – The device where the tensor will be allocated. If `None`,
uses the default device (accelerator if available, otherwise CPU).
Defaults to `None`.
**Returns:**
A [`Tensor`](tensor.md#max.tensor.Tensor) with random values sampled from
the uniform distribution.
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If the range tuple does not contain exactly two values
or if min >= max.
---
## tensor (Python)
Provides tensor operations with eager execution capabilities.
This module provides the [`Tensor`](#max.tensor.Tensor) class which supports
eager execution of tensor operations, complementing the graph-based execution
model provided by `graph`. The tensor operations automatically compile
and execute using the MAX runtime.
**Key Features:**
* **Eager semantics**: Operations give immediate results for quick iteration and feedback.
* **High performance**: All operations use high-performance Mojo implementations
compiled specifically for the available hardware.
* **Automatic compilation**: Tensors are compiled and optimized automatically.
Operations may be easily fused into larger graphs to take advantage of
the graph compiler’s automatic fusions.
* **Lazy evaluation**: Tensors may be computed lazily until their values are needed.
* **Familiar API**: Supports common array operations and indexing.
:::note Note
Tensors use lazy evaluation and JIT compilation, which incurs compilation
overhead on first execution. This can result in higher latency for initial
operations compared to eager frameworks like NumPy or PyTorch. Subsequent
executions reuse compiled kernels for better performance.
:::
Create and manipulate tensors with automatic compilation and optimization:
```python
from max.tensor import Tensor
from max.driver import CPU
from max.dtype import DType
x = Tensor.ones((2, 3), dtype=DType.float32, device=CPU())
y = Tensor.zeros_like(x)
result = x + y # Eager execution with automatic compilation
```
Operations may be combined into a single execution graph to take advantage
of automatic kernel fusion:
```python
from max import functional as F
from max.tensor import Tensor

@F.functional
def linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
    return x @ weight.T + bias

# Create and operate on tensors
x = Tensor.ones([2, 3])
weight = Tensor.ones([6, 3])
bias = Tensor.ones([6])

# Eager execution with a single fused graph
result = linear(x, weight, bias)
```
Users may opt in to lazy execution. This is primarily useful for
1. Operations which may never execute, for instance creating modules
with randomly initialized weights before loading weights
2. Combining many operations into a single execution
```python
from max import functional as F
from max.driver import CPU
from max.dtype import DType
from max.nn import Linear
from max.tensor import Tensor

with F.lazy():
    model = Linear(2, 3)
    print(model)  # Lazy weights not initialized

# Load pretrained weights
weights = {
    "weight": Tensor.zeros([3, 2]),
    "bias": Tensor.zeros([3]),
}
model.load_state_dict(weights)

# Or compile directly without ever initializing weights
from max.graph import TensorType

input_type = TensorType(DType.float32, ["batch", 2], CPU())
model = model.compile(input_type, weights=weights)
```
## `RealizationContext` {#max.tensor.RealizationContext}
> class max.tensor.RealizationContext(\*args, \*\*kwargs)
Implements a way to realize unrealized tensors.
Most users should never have to think about the existence of this type.
It exists to facilitate optimizations around where and when tensor
operations are executed.
* Each tensor is either real or associated with a RealizationContext.
* If a tensor is not real, i.e. “unrealized”, then it is backed by some
symbolic computation.
* The RealizationContext is responsible for tracking this symbolic
computation and “realizing” the tensor (executing the computation and
backing the tensor with real data) if and when it is asked to do so.
* A RealizationContext can only realize tensors associated with it.
RealizationContext abstracts over various semantics of tensor construction.
**“Eager” execution**: tensors are realized as soon as the realization context
exits. This is the default behavior.
This has a concrete advantage over eagerly executing one operation
at a time: by controlling where the eager context starts and ends, we
give advanced users a tool to set fine-grained bounds for automatic
fusion.
In practice the easiest way to do this is to mark a function as
F.functional. This function is then assumed to be “atomic” for the
purposes of eager execution. All ops within the function execute as
part of the same graph, meaning the compiler is free to fuse operations
and generate fused kernels within this region.
**“Lazy” execution**: tensors are realized only when code later tries to use
them.
This enables a class of interface design common in the ML world, in
which layers are constructed with randomized weights which are never
used. Lazy execution neatly allows constructing entire models,
only performing the weight initialization and allocating memory for
them if and when those weights are actually used.
**Graph compilation**: tensors may never be realized.
This allows tensor operations to be composed with direct usage of
the Graph API, for instance Module.compile, or using F.\* operations
in another Graph API usage.
**Async execution**: Tensors are realized as async functions,
allowing clean integration in async systems like web services.
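The eager/lazy distinction above can be modeled with thunks: an unrealized value holds a deferred computation, and realizing it runs that computation once and caches the result. The class below is a toy model of that idea, not MAX's graph machinery:

```python
class ToyUnrealized:
    """An unrealized value backed by a deferred computation."""
    def __init__(self, compute):
        self._compute = compute
        self._value = None
        self.realized = False

    def realize(self):
        # Execute the backing computation exactly once, then cache.
        if not self.realized:
            self._value = self._compute()
            self.realized = True
        return self._value

calls = []
t = ToyUnrealized(lambda: calls.append("ran") or 42)
before = t.realized   # nothing has executed yet
value = t.realize()   # the computation runs here
again = t.realize()   # cached; not recomputed
```

Under lazy execution, model weights behave like `t` before `realize()`: they cost nothing until something actually reads their values.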
### `add_source()` {#max.tensor.RealizationContext.add_source}
> add\_source(tensor)
Adds a realized tensor as a “source” of the realization state,
i.e. one on whose values unrealized tensors depend.
**Parameters:**
tensor ([Tensor](#max.tensor.Tensor)) – The realized tensor to add as a source to the computation.
**Returns:**
A realization state for the tensor. This may be used to compute
downstream unrealized values. If it is used in any mutating
operations, it should be assigned to `tensor.state` to mark
the tensor as having been mutated.
**Return type:**
[RealizationState](#max.tensor.RealizationState)
### `create_unrealized()` {#max.tensor.RealizationContext.create_unrealized}
> create\_unrealized(value)
Registers an unrealized graph value with the realization context
and returns it as an unrealized tensor.
**Parameters:**
value ([BufferValue](graph/BufferValue.md#max.graph.BufferValue) | [TensorValue](graph/TensorValue.md#max.graph.TensorValue)) – The graph value representing the result of a computation.
**Returns:**
A new tensor associated with the unrealized value.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `graph` {#max.tensor.RealizationContext.graph}
> graph: [Graph](graph/Graph.md#max.graph.Graph)
The graph used by the realization context.
### `realize_all()` {#max.tensor.RealizationContext.realize_all}
> async realize\_all()
Realizes all unrealized tensors associated with this context.
## `RealizationState` {#max.tensor.RealizationState}
> class max.tensor.RealizationState(value, ctx)
State for an unrealized tensor.
See [`RealizationContext`](#max.tensor.RealizationContext).
**Parameters:**
* value ([BufferValue](graph/BufferValue.md#max.graph.BufferValue) | [TensorValue](graph/TensorValue.md#max.graph.TensorValue))
* ctx ([RealizationContext](#max.tensor.RealizationContext))
### `ctx` {#max.tensor.RealizationState.ctx}
> ctx: [RealizationContext](#max.tensor.RealizationContext)
The realization context used to create this tensor. This context
is responsible for realizing the tensor to a real value.
### `value` {#max.tensor.RealizationState.value}
> value: [BufferValue](graph/BufferValue.md#max.graph.BufferValue) | [TensorValue](graph/TensorValue.md#max.graph.TensorValue)
The symbolic value representing the computation backing this tensor.
## `Tensor` {#max.tensor.Tensor}
> class max.tensor.Tensor(\*, storage=None, state=None)
A multi-dimensional array with eager execution and automatic compilation.
The Tensor class provides a high-level interface for numerical computations
with automatic compilation and optimization via the MAX runtime. Operations
on tensors execute eagerly while benefiting from lazy evaluation and
graph-based optimizations behind the scenes.
**Key Features:**
* **Eager execution**: Operations execute immediately with automatic compilation.
* **Lazy evaluation**: Computation may be deferred until results are needed.
* **High performance**: Uses the Mojo compiler and optimized kernels.
* **Familiar API**: Supports common array operations and indexing.
* **Device flexibility**: Works seamlessly across CPU and accelerators.
**Creating Tensors:**
Create tensors using factory methods like [`ones()`](#max.tensor.Tensor.ones), [`zeros()`](#max.tensor.Tensor.zeros),
[`constant()`](#max.tensor.Tensor.constant), [`arange()`](#max.tensor.Tensor.arange), or from other array libraries via
[`from_dlpack()`](#max.tensor.Tensor.from_dlpack).
```python
from max import tensor
from max.dtype import DType
# Create tensors with factory methods
x = tensor.Tensor.ones((2, 3), dtype=DType.float32)
y = tensor.Tensor.zeros((2, 3), dtype=DType.float32)
# Perform operations
result = x + y # Eager execution with automatic compilation
# Access values
print(result.shape) # (2, 3)
print(result.dtype) # DType.float32
```
**Implementation Notes:**
Tensors use lazy evaluation internally - they don’t always hold concrete
data in memory. A tensor may be “unrealized” (not yet computed) until its
value is actually needed (e.g., when converting to other formats or calling
[`item()`](#max.tensor.Tensor.item)). This allows the runtime to optimize sequences of
operations efficiently.
Operations on tensors build a computation graph behind the scenes, which is
compiled and executed when needed. All illegal operations fail immediately
with clear error messages, ensuring a smooth development experience.
:::note Note
The lazy evaluation model and JIT compilation introduce compilation overhead
on first execution of operations. This results in higher latency for
interactive operations compared to eager frameworks like NumPy or PyTorch,
particularly when materializing tensor values (e.g., printing or converting
to other formats). Subsequent operations on similar shapes and dtypes reuse
compiled kernels for improved performance.
:::
**Interoperability:**
Tensors support the DLPack protocol for zero-copy data exchange with NumPy,
PyTorch, JAX, and other array libraries. Use [`from_dlpack()`](#max.tensor.Tensor.from_dlpack) to import
arrays and standard DLPack conversion for export.
**Parameters:**
* storage ([Buffer](driver.md#max.driver.Buffer) | None)
* state ([RealizationState](#max.tensor.RealizationState) | None)
### `T` {#max.tensor.Tensor.T}
> property T: [Tensor](#max.tensor.Tensor)
Returns a tensor with the last two dimensions transposed.
This is equivalent to calling `transpose(-1, -2)`, which swaps
the last two dimensions of the tensor. For a 2D matrix, this produces
the standard matrix transpose.
```python
from max.tensor import Tensor
from max.dtype import DType
# Create a 2x3 matrix
x = Tensor.constant([[1, 2, 3], [4, 5, 6]], dtype=DType.int32)
print(f"Original shape: {x.shape}")
# Output: Original shape: [Dim(2), Dim(3)]
# Use .T property (equivalent to transpose(-1, -2))
y = x.T
print(f"Transposed shape: {y.shape}")
# Output: Transposed shape: [Dim(3), Dim(2)]
print(y)
```
**Returns:**
A tensor with the last two dimensions transposed.
### `arange()` {#max.tensor.Tensor.arange}
> classmethod arange(start=0, stop=None, step=1, \*, dtype=None, device=None)
Creates a tensor with evenly spaced values within a given interval.
Returns a new 1D tensor containing a sequence of values starting from
`start` (inclusive) and ending before `stop` (exclusive), with values
spaced by `step`. This is similar to Python’s built-in `range()`
function and NumPy’s `arange()`.
```python
from max import tensor
from max.dtype import DType
# Create a range from 0 to 10 (exclusive)
x = tensor.Tensor.arange(10)
# Result: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
# Create a range from 5 to 15 with step 2
y = tensor.Tensor.arange(5, 15, 2)
# Result: [5, 7, 9, 11, 13]
# Use a specific dtype
z = tensor.Tensor.arange(0, 5, dtype=DType.float32)
# Result: [0.0, 1.0, 2.0, 3.0, 4.0]
# Create a range with float step (like numpy/pytorch)
w = tensor.Tensor.arange(0.0, 1.0, 0.2, dtype=DType.float32)
# Result: [0.0, 0.2, 0.4, 0.6, 0.8]
# Create a descending range with negative step
v = tensor.Tensor.arange(5, 0, -1, dtype=DType.float32)
# Result: [5.0, 4.0, 3.0, 2.0, 1.0]
```
**Parameters:**
* start (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) – The starting value of the sequence. If `stop` is not provided,
this becomes the `stop` value and `start` defaults to 0.
* stop (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray) | None) – The end value of the sequence (exclusive). If not specified,
the sequence ends at `start` and begins at 0.
* step (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray)) – The spacing between values in the sequence. Must be non-zero.
* dtype ([DType](dtype.md#max.dtype.DType) | None) – The data type for the tensor elements. If not specified,
defaults to `DType.float32` for CPU devices and
`DType.bfloat16` for accelerator devices.
* device ([Device](driver.md#max.driver.Device) | None) – The device where the tensor will be allocated. If not
specified, defaults to an accelerator if available, otherwise CPU.
**Returns:**
A 1D tensor containing the evenly spaced values.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `argmax()` {#max.tensor.Tensor.argmax}
> argmax(axis=-1)
Finds the indices of the maximum values along an axis.
Returns a tensor containing the indices of the maximum values along
the specified axis. This is useful for finding the position of the
largest element, such as determining predicted classes in classification.
```python
from max import tensor
from max.dtype import DType
# Create a 2x4 tensor
x = tensor.Tensor.constant(
[[1.2, 3.5, 2.1, 0.8], [2.3, 1.9, 4.2, 3.1]], dtype=DType.float32
)
# Find argmax along last axis (within each row)
indices = x.argmax(axis=-1)
# Result: [1, 2] (index 1 in first row, index 2 in second row)
# Find argmax over all elements
index = x.argmax(axis=None)
# Result: 6 (flattened index of maximum value 4.2)
```
**Parameters:**
axis ([int](https://docs.python.org/3/library/functions.html#int) | None) – The axis along which to find the maximum indices. Defaults
to -1 (the last axis). If None, finds the index of the maximum
value across all elements.
**Returns:**
A tensor containing the indices of the maximum values.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `broadcast_to()` {#max.tensor.Tensor.broadcast_to}
> broadcast\_to(shape)
Broadcasts the tensor to the specified shape.
Returns a tensor broadcast to the target shape, following NumPy
broadcasting semantics. Dimensions of size 1 in the input can be
expanded to match larger dimensions in the target shape.
This is equivalent to PyTorch’s `torch.broadcast_to()` and
`torch.Tensor.expand()`.
```python
from max import tensor
from max.dtype import DType
# Create a tensor with shape (3, 1)
x = tensor.Tensor.ones([3, 1], dtype=DType.float32)
# Broadcast to (3, 4) - expands the second dimension
y = x.broadcast_to([3, 4])
print(y.shape) # (3, 4)
# Add a new leading dimension
w = x.broadcast_to([2, 3, 1])
print(w.shape) # (2, 3, 1)
```
**Parameters:**
shape ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The target shape. Each dimension must either match the input
dimension or be broadcastable from size 1.
**Returns:**
A tensor broadcast to the specified shape.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `cast()` {#max.tensor.Tensor.cast}
> cast(dtype)
Casts the tensor to a different data type.
Returns a new tensor with the same values but a different data type.
This is useful for type conversions between different numeric types,
such as converting `float32` to `int32` for indexing operations or
`float32` to `bfloat16` for memory-efficient computations.
```python
from max import tensor
from max.dtype import DType
# Create a float32 tensor
x = tensor.Tensor.constant([1.7, 2.3, 3.9], dtype=DType.float32)
print(x.dtype) # DType.float32
# Cast to int32 (truncates decimal values)
y = x.cast(DType.int32)
print(y.dtype) # DType.int32
# Values: [1, 2, 3]
```
**Parameters:**
dtype ([DType](dtype.md#max.dtype.DType)) – The target data type for the tensor.
**Returns:**
A new tensor with the specified data type.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `clip()` {#max.tensor.Tensor.clip}
> clip(\*, min=None, max=None)
Clips values outside a range to the boundaries of the range.
```python
from max import tensor
# Create a 2x4 tensor
x = tensor.Tensor.constant(
[[1.2, 3.5, 2.1, 0.8], [2.3, 1.9, 4.2, 3.1]]
)
# Clip values above 3.0
clipped_above = x.clip(max=3.)
# Result: [[1.2, 3.0, 2.1, 0.8], [2.3, 1.9, 3.0, 3.0]]
# Clip values below 3.0
clipped_below = x.clip(min=3.)
# Result: [[3.0, 3.5, 3.0, 3.0], [3.0, 3.0, 4.2, 3.1]]
```
**Parameters:**
* min (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray) | None) – The minimum value of the range. If not specified, values
are not clipped from below.
* max (Value\[TensorType] | [TensorValue](graph/TensorValue.md#max.graph.TensorValue) | [Shape](graph/shape.md#max.graph.shape.Shape) | [Dim](graph/dim.md#max.graph.dim.Dim) | HasTensorValue | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [DLPackArray](driver.md#max.driver.DLPackArray) | None) – The maximum value of the range. If not specified, values
are not clipped from above.
**Returns:**
A tensor containing the values clipped to the specified range.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `constant()` {#max.tensor.Tensor.constant}
> classmethod constant(value, \*, dtype=None, device=None)
Creates a constant tensor from a scalar, array, or nested list.
Constructs a tensor with constant values that can be a scalar, a nested
Python list, or a DLPack-compatible array. The shape is automatically
inferred from the input data structure.
```python
from max import tensor
from max.dtype import DType
# Create from scalar
x = tensor.Tensor.constant(42, dtype=DType.int32)
# Create from nested list
y = tensor.Tensor.constant([[1.0, 2.0], [3.0, 4.0]])
# Create from NumPy array
import numpy as np
z = tensor.Tensor.constant(np.array([1, 2, 3]))
```
**Parameters:**
* value ([DLPackArray](driver.md#max.driver.DLPackArray) | [Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[float](https://docs.python.org/3/library/functions.html#float) | [number](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.number)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[Number | NestedArray]] | [float](https://docs.python.org/3/library/functions.html#float) | [number](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.number)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]) – The constant value for the tensor. Can be a scalar number,
a nested Python list, or any DLPack-compatible array.
* dtype ([DType](dtype.md#max.dtype.DType) | None) – The data type for the tensor elements. If not specified,
defaults to `DType.float32` for CPU devices and
`DType.bfloat16` for accelerator devices.
* device ([Device](driver.md#max.driver.Device) | None) – The device where the tensor will be allocated. If not
specified, defaults to an accelerator if available, otherwise CPU.
**Returns:**
A new tensor containing the constant value(s).
**Return type:**
[Tensor](#max.tensor.Tensor)
### `device` {#max.tensor.Tensor.device}
> property device: [Device](driver.md#max.driver.Device)
Gets the device where the tensor is stored.
Returns the device (CPU or accelerator) where the tensor’s data is
located.
**Returns:**
The device where the tensor is stored.
**Return type:**
[Device](driver.md#max.driver.Device)
### `driver_tensor` {#max.tensor.Tensor.driver_tensor}
> property driver\_tensor: [Buffer](driver.md#max.driver.Buffer)
A pointer to the underlying memory.
Raises an error if the tensor is unrealized.
### `dtype` {#max.tensor.Tensor.dtype}
> property dtype: [DType](dtype.md#max.dtype.DType)
Gets the data type of the tensor elements.
Returns the data type (dtype) of the elements stored in the tensor,
such as `float32`, `int32`, or `bfloat16`.
**Returns:**
The data type of the tensor elements.
**Return type:**
[DType](dtype.md#max.dtype.DType)
### `from_dlpack()` {#max.tensor.Tensor.from_dlpack}
> classmethod from\_dlpack(array)
Creates a tensor from a DLPack array.
Constructs a tensor by importing data from any object that supports
the DLPack protocol (such as NumPy arrays and PyTorch tensors).
This enables zero-copy interoperability with other array libraries.
```python
import numpy as np
from max import tensor
# Create a NumPy array
np_array = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
# Convert to MAX tensor via DLPack
x = tensor.Tensor.from_dlpack(np_array)
```
**Parameters:**
array ([DLPackArray](driver.md#max.driver.DLPackArray)) – Any object supporting the DLPack protocol, such as NumPy
arrays, PyTorch tensors, or JAX arrays.
**Returns:**
A new tensor containing the data from the DLPack array.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `from_graph_value()` {#max.tensor.Tensor.from_graph_value}
> classmethod from\_graph\_value(value)
Creates a tensor from a graph value.
Constructs a tensor from an existing graph value, which can be either
a [`TensorValue`](graph/TensorValue.md#max.graph.TensorValue) or [`BufferValue`](graph/BufferValue.md#max.graph.BufferValue). This
is used for converting graph-level values into tensor objects.
The new tensor is registered as unrealized, backed by the current
realization context.
**Parameters:**
value ([Value](graph/Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]) – The graph value to wrap. Can be either a TensorValue or
BufferValue from the MAX graph API.
**Returns:**
A new tensor backed by the provided graph value.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `full()` {#max.tensor.Tensor.full}
> classmethod full(shape, value, \*, dtype=None, device=None)
Creates a tensor filled with a specified value.
Returns a new tensor with the given shape where all elements are
initialized to the specified value. This is useful for creating
tensors with uniform values other than zero or one.
```python
from max import tensor
from max.dtype import DType
# Create a 3x3 tensor filled with 7
x = tensor.Tensor.full((3, 3), value=7, dtype=DType.int32)
# Create a 2x4 tensor filled with pi
y = tensor.Tensor.full((2, 4), value=3.14159)
```
**Parameters:**
* shape ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The shape of the output tensor. Can be a tuple of integers,
a list of integers, or any value that can be converted to a shape.
* value ([float](https://docs.python.org/3/library/functions.html#float) | [number](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.number)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]) – The scalar value to fill the tensor with.
* dtype ([DType](dtype.md#max.dtype.DType) | None) – The data type for the tensor elements. If not specified,
defaults to `DType.float32` for CPU devices and
`DType.bfloat16` for accelerator devices.
* device ([Device](driver.md#max.driver.Device) | None) – The device where the tensor will be allocated. If not
specified, defaults to an accelerator if available, otherwise CPU.
**Returns:**
A new tensor with the specified shape filled with the given value.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `full_like()` {#max.tensor.Tensor.full_like}
> classmethod full\_like(input, value)
Creates a tensor filled with a value, matching a given tensor’s properties.
Returns a new tensor filled with the specified value that matches the
shape, data type, and device of the input tensor. This behaves like
NumPy’s `full_like` and PyTorch’s `full_like`.
```python
from max import tensor
from max.dtype import DType
# Create a reference tensor
ref = tensor.Tensor.ones([2, 3], dtype=DType.float32)
# Create tensor filled with 5.0 matching the reference tensor
x = tensor.Tensor.full_like(ref, value=5.0)
```
**Parameters:**
* input ([Tensor](#max.tensor.Tensor) | [TensorType](graph/type.md#max.graph.type.TensorType)) – The tensor or tensor type to match. The returned tensor will
have the same shape, dtype, and device as this input.
* value ([float](https://docs.python.org/3/library/functions.html#float) | [number](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.number)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]) – The scalar value to fill the tensor with.
**Returns:**
A new tensor filled with the specified value, matching the
properties of the input.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `item()` {#max.tensor.Tensor.item}
> item()
Gets the scalar value from a single-element tensor.
Extracts and returns the scalar value from a tensor containing exactly
one element. The tensor is realized if needed and transferred to CPU
before extracting the value.
**Returns:**
The scalar value from the tensor. The return type matches the tensor’s
dtype (e.g., float for float32, int for int32).
**Raises:**
[TypeError](https://docs.python.org/3/library/exceptions.html#TypeError) – If the tensor contains more than one element.
### `max()` {#max.tensor.Tensor.max}
> max(axis=-1)
Computes the maximum values along an axis.
Returns a tensor containing the maximum values along the specified axis.
This is useful for reduction operations and finding peak values in data.
```python
from max import tensor
from max.dtype import DType
# Create a 2x4 tensor
x = tensor.Tensor.constant(
[[1.2, 3.5, 2.1, 0.8], [2.3, 1.9, 4.2, 3.1]], dtype=DType.float32
)
# Find max along last axis (within each row)
row_max = x.max(axis=-1)
# Result: [3.5, 4.2]
# Find max along first axis (within each column)
col_max = x.max(axis=0)
# Result: [2.3, 3.5, 4.2, 3.1]
# Find max over all elements
overall_max = x.max(axis=None)
# Result: 4.2 (maximum value across all elements)
```
**Parameters:**
axis ([int](https://docs.python.org/3/library/functions.html#int) | None) – The axis along which to compute the maximum. Defaults to -1
(the last axis). If None, computes the maximum across all elements.
**Returns:**
A tensor containing the maximum values along the specified axis.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `mean()` {#max.tensor.Tensor.mean}
> mean(axis=-1)
Computes the mean values along an axis.
Returns a tensor containing the arithmetic mean of values along the
specified axis. This is useful for computing averages, normalizing data,
or aggregating statistics.
```python
from max import tensor
from max.dtype import DType
# Create a 2x4 tensor
x = tensor.Tensor.constant(
[[2.0, 4.0, 6.0, 8.0], [1.0, 3.0, 5.0, 7.0]], dtype=DType.float32
)
# Compute mean along last axis (within each row)
row_mean = x.mean(axis=-1)
# Result: [5.0, 4.0] (mean of each row)
# Compute mean along first axis (within each column)
col_mean = x.mean(axis=0)
# Result: [1.5, 3.5, 5.5, 7.5] (mean of each column)
# Compute mean over all elements
overall_mean = x.mean(axis=None)
# Result: 4.5 (mean of all elements)
```
**Parameters:**
axis ([int](https://docs.python.org/3/library/functions.html#int) | None) – The axis along which to compute the mean. Defaults to -1
(the last axis). If None, computes the mean across all elements.
**Returns:**
A tensor containing the mean values along the specified axis.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `min()` {#max.tensor.Tensor.min}
> min(axis=-1)
Computes the minimum values along an axis.
Returns a tensor containing the minimum values along the specified axis.
This is useful for reduction operations and finding the smallest values
in data.
```python
from max import tensor
from max.dtype import DType
# Create a 2x4 tensor
x = tensor.Tensor.constant(
[[1.2, 3.5, 2.1, 0.8], [2.3, 1.9, 4.2, 3.1]], dtype=DType.float32
)
# Find min along last axis (within each row)
row_min = x.min(axis=-1)
# Result: [0.8, 1.9]
# Find min along first axis (within each column)
col_min = x.min(axis=0)
# Result: [1.2, 1.9, 2.1, 0.8]
# Find min over all elements
overall_min = x.min(axis=None)
# Result: 0.8 (minimum value across all elements)
```
**Parameters:**
axis ([int](https://docs.python.org/3/library/functions.html#int) | None) – The axis along which to compute the minimum. Defaults to -1
(the last axis). If None, computes the minimum across all elements.
**Returns:**
A tensor containing the minimum values along the specified axis.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `num_elements()` {#max.tensor.Tensor.num_elements}
> num\_elements()
Gets the total number of elements in the tensor.
Computes the product of all dimensions in the tensor’s shape to
determine the total number of elements.
**Returns:**
The total number of elements in the tensor.
**Return type:**
[int](https://docs.python.org/3/library/functions.html#int)
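For example, using the `zeros()` constructor documented on this page:

```python
from max import tensor
from max.dtype import DType

# A 2x3x4 tensor has 2 * 3 * 4 = 24 elements
x = tensor.Tensor.zeros((2, 3, 4), dtype=DType.float32)
print(x.num_elements())  # 24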
### `ones()` {#max.tensor.Tensor.ones}
> classmethod ones(shape, \*, dtype=None, device=None)
Creates a tensor filled with ones.
Returns a new tensor with the specified shape where all elements are
initialized to one. The tensor is created with eager execution and
automatic compilation.
```python
from max import tensor
from max.driver import CPU
from max.dtype import DType
# Create a 2x3 tensor of ones
x = tensor.Tensor.ones((2, 3), dtype=DType.float32, device=CPU())
# Result: [[1.0, 1.0, 1.0],
# [1.0, 1.0, 1.0]]
# Create a 1D tensor using default dtype and device
y = tensor.Tensor.ones((5,))
```
**Parameters:**
* shape ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The shape of the output tensor. Can be a tuple of integers,
a list of integers, or any value that can be converted to a shape.
* dtype ([DType](dtype.md#max.dtype.DType) | None) – The data type for the tensor elements. If not specified,
defaults to `DType.float32` for CPU devices and
`DType.bfloat16` for accelerator devices.
* device ([Device](driver.md#max.driver.Device) | None) – The device where the tensor will be allocated. If not
specified, defaults to an accelerator if available, otherwise CPU.
**Returns:**
A new tensor with the specified shape filled with ones.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `ones_like()` {#max.tensor.Tensor.ones_like}
> classmethod ones\_like(input)
Creates a tensor of ones matching a given tensor’s properties.
Returns a new tensor filled with ones that matches the shape, data type,
and device of the input tensor. This behaves like NumPy’s `ones_like`
and PyTorch’s `ones_like`.
```python
from max import tensor
from max.dtype import DType
# Create a reference tensor
ref = tensor.Tensor.zeros([3, 4], dtype=DType.float32)
# Create ones tensor matching the reference tensor
x = tensor.Tensor.ones_like(ref)
# Result: 3x4 tensor of ones with dtype float32
```
**Parameters:**
input ([Tensor](#max.tensor.Tensor) | [TensorType](graph/type.md#max.graph.type.TensorType)) – The tensor or tensor type to match. The returned tensor will
have the same shape, dtype, and device as this input.
**Returns:**
A new tensor filled with ones matching the properties of the
input.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `permute()` {#max.tensor.Tensor.permute}
> permute(dims)
Permutes the dimensions of the tensor.
Returns a tensor with its dimensions reordered according to the
specified permutation. This is useful for changing the layout of
multi-dimensional data, such as converting between different tensor
layout conventions (e.g., from `[batch, channels, height, width]`
to `[batch, height, width, channels]`).
```python
from max.tensor import Tensor
from max.dtype import DType
# Create a 3D tensor (batch_size=2, channels=3, length=4)
x = Tensor.constant([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]],
dtype=DType.int32)
print(f"Original shape: {x.shape}")
# Output: Original shape: [Dim(2), Dim(3), Dim(4)]
# Rearrange to (batch, length, channels)
y = x.permute([0, 2, 1])
print(f"Permuted shape: {y.shape}")
# Output: Permuted shape: [Dim(2), Dim(4), Dim(3)]
```
**Parameters:**
dims ([list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]) – A list specifying the new order of dimensions. For example,
`[2, 0, 1]` moves dimension 2 to position 0, dimension 0 to
position 1, and dimension 1 to position 2.
**Returns:**
A tensor with permuted dimensions.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `range_like()` {#max.tensor.Tensor.range_like}
> classmethod range\_like(type)
Creates a range tensor matching a given type’s properties.
Returns a new tensor containing sequential indices along the last
dimension, broadcasted to match the shape of the specified tensor type.
Each row (along the last dimension) contains values from 0 to the
dimension size minus one. This is useful for creating position indices
or coordinate tensors.
```python
from max import tensor
from max.graph import TensorType
from max.driver import CPU
from max.dtype import DType
# Create a reference tensor type with shape (2, 4)
ref_type = TensorType(DType.int32, (2, 4), device=CPU())
# Create range tensor matching the reference type
x = tensor.Tensor.range_like(ref_type)
# Result: [[0, 1, 2, 3],
# [0, 1, 2, 3]]
```
**Parameters:**
type ([TensorType](graph/type.md#max.graph.type.TensorType)) – The tensor type to match. The returned tensor will have the
same shape, dtype, and device as this type, with values
representing indices along the last dimension.
**Returns:**
A new tensor with sequential indices broadcasted to match
the input type’s shape.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `rank` {#max.tensor.Tensor.rank}
> property rank: [int](https://docs.python.org/3/library/functions.html#int)
Gets the number of dimensions in the tensor.
Returns the rank (number of dimensions) of the tensor. For example,
a scalar has rank 0, a vector has rank 1, and a matrix has rank 2.
### `real` {#max.tensor.Tensor.real}
> property real: [bool](https://docs.python.org/3/library/functions.html#bool)
### `realize` {#max.tensor.Tensor.realize}
> property realize: [Tensor](#max.tensor.Tensor)
Forces the tensor to realize if it is not already realized.
### `reshape()` {#max.tensor.Tensor.reshape}
> reshape(shape)
Reshapes the tensor to a new shape.
Returns a tensor with the same data but a different shape. The total
number of elements must remain the same. This is useful for changing
tensor dimensions for different operations, such as flattening a
multi-dimensional tensor or converting a 1D tensor into a matrix.
```python
from max import tensor
from max.dtype import DType
# Create a 2x3 tensor
x = tensor.Tensor.constant([[1, 2, 3], [4, 5, 6]], dtype=DType.int32)
print(x.shape) # (2, 3)
# Flatten to 1D
y = x.reshape((6,))
print(y.shape) # (6,)
# Values: [1, 2, 3, 4, 5, 6]
```
**Parameters:**
shape ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The desired output shape. Can be a tuple or list of integers.
The total number of elements must equal the original tensor’s
element count.
**Returns:**
A reshaped tensor with the specified shape.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `shape` {#max.tensor.Tensor.shape}
> property shape: [Shape](graph/shape.md#max.graph.shape.Shape)
Gets the shape of the tensor.
Returns the dimensions of the tensor as a shape object.
**Returns:**
The shape of the tensor.
**Return type:**
[Shape](graph/shape.md#max.graph.shape.Shape)
### `split()` {#max.tensor.Tensor.split}
> split(split\_size\_or\_sections, axis=0)
Splits the tensor into multiple tensors along a given dimension.
This method supports two modes, matching PyTorch’s behavior:
* If `split_size_or_sections` is an **int**, splits into chunks of
that size (the last chunk may be smaller if not evenly divisible).
* If `split_size_or_sections` is a **list of ints**, splits into
chunks with exactly those sizes (must sum to the dimension size).
```python
from max import tensor
from max.dtype import DType
# Create a 10x4 tensor
x = tensor.Tensor.ones([10, 4], dtype=DType.float32)
# Split into chunks of size 3 (last chunk is size 1)
chunks = x.split(3, axis=0)
# Result: 4 tensors with shapes [3,4], [3,4], [3,4], [1,4]
# Split into exact sizes
chunks = x.split([2, 3, 5], axis=0)
# Result: 3 tensors with shapes [2,4], [3,4], [5,4]
```
**Parameters:**
* split\_size\_or\_sections ([int](https://docs.python.org/3/library/functions.html#int) | [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]) – Either an int (chunk size) or a list of
ints (exact sizes for each output tensor).
* axis ([int](https://docs.python.org/3/library/functions.html#int)) – The dimension along which to split. Defaults to 0.
**Returns:**
The tensors produced by the split, in order along the given axis.
### `squeeze()` {#max.tensor.Tensor.squeeze}
> squeeze(axis)
Removes a size-1 dimension from the tensor.
Returns a tensor with the specified size-1 dimension removed. This is
useful for removing singleton dimensions from tensors after operations
that may have added them.
```python
from max import tensor
from max.dtype import DType
# Create a tensor with a size-1 dimension
x = tensor.Tensor.ones([4, 1, 6], dtype=DType.float32)
print(x.shape) # (4, 1, 6)
# Squeeze out the size-1 dimension
y = x.squeeze(axis=1)
print(y.shape) # (4, 6)
```
**Parameters:**
axis ([int](https://docs.python.org/3/library/functions.html#int)) – The dimension to remove from the tensor’s shape. If negative,
this indexes from the end of the tensor. The dimension at this
axis must have size 1.
**Returns:**
A tensor with the specified dimension removed.
**Return type:**
[Tensor](#max.tensor.Tensor)
**Raises:**
[ValueError](https://docs.python.org/3/library/exceptions.html#ValueError) – If the dimension at the specified axis is not size 1.
### `state` {#max.tensor.Tensor.state}
> state: [RealizationState](#max.tensor.RealizationState) | [None](https://docs.python.org/3/library/constants.html#None)
State for realizing an unrealized tensor.
### `storage` {#max.tensor.Tensor.storage}
> storage: [Buffer](driver.md#max.driver.Buffer) | [None](https://docs.python.org/3/library/constants.html#None)
Underlying memory for a realized tensor.
If the tensor is used in any mutating operations that have
not been realized, this holds the state before any updates.
### `sum()` {#max.tensor.Tensor.sum}
> sum(axis=-1)
Computes the sum of values along an axis.
Returns a tensor containing the sum of values along the specified axis.
This is a fundamental reduction operation used for aggregating data,
computing totals, and implementing other operations like mean.
```python
from max import tensor
from max.dtype import DType
# Create a 2x3 tensor
x = tensor.Tensor.constant(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=DType.float32
)
# Sum along last axis (within each row)
row_sum = x.sum(axis=-1)
# Result: [6.0, 15.0] (sum of each row)
# Sum along first axis (within each column)
col_sum = x.sum(axis=0)
# Result: [5.0, 7.0, 9.0] (sum of each column)
# Sum over all elements
total = x.sum(axis=None)
# Result: 21.0 (sum of all elements)
```
**Parameters:**
axis ([int](https://docs.python.org/3/library/functions.html#int) | None) – The axis along which to compute the sum. Defaults to -1
(the last axis). If None, computes the sum across all elements.
**Returns:**
A tensor containing the sum along the specified axis.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `to()` {#max.tensor.Tensor.to}
> to(device)
Transfers the tensor to a different device.
Creates a new tensor with the same data on the specified device. This
allows moving tensors between CPU and accelerators or between different
accelerator devices.
```python
from max import tensor
from max.driver import CPU, Accelerator
# Create a tensor on CPU
x = tensor.Tensor.ones((2, 3), device=CPU())
print(x.device) # CPU
# Transfer to accelerator
y = x.to(Accelerator())
print(y.device) # Accelerator(0)
```
**Parameters:**
device ([Device](driver.md#max.driver.Device)) – The target device for the tensor.
**Returns:**
A new tensor with the same data on the specified device.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `transpose()` {#max.tensor.Tensor.transpose}
> transpose(dim1, dim2)
Returns a tensor that is a transposed version of input.
The given dimensions `dim1` and `dim2` are swapped.
```python
from max.tensor import Tensor
from max.dtype import DType
# Create a 2x3 matrix
x = Tensor.constant([[1, 2, 3], [4, 5, 6]], dtype=DType.int32)
print(f"Original shape: {x.shape}")
# Output: Original shape: [Dim(2), Dim(3)]
print(x)
# Transpose dimensions 0 and 1 to get a 3x2 matrix
y = x.transpose(0, 1)
print(f"Transposed shape: {y.shape}")
# Output: Transposed shape: [Dim(3), Dim(2)]
print(y)
```
**Parameters:**
* dim1 ([int](https://docs.python.org/3/library/functions.html#int)) – The first dimension to be transposed.
* dim2 ([int](https://docs.python.org/3/library/functions.html#int)) – The second dimension to be transposed.
**Returns:**
A tensor with dimensions `dim1` and `dim2` swapped.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `type` {#max.tensor.Tensor.type}
> property type: [TensorType](graph/type.md#max.graph.type.TensorType)
Gets the tensor type information.
Returns the type information for the tensor, including shape, dtype,
and device. If the underlying value is a buffer type, it’s converted
to a tensor type.
### `unsqueeze()` {#max.tensor.Tensor.unsqueeze}
> unsqueeze(axis)
Inserts a size-1 dimension into the tensor.
Returns a tensor with a new size-1 dimension inserted at the specified
position. This is the inverse of [`squeeze()`](#max.tensor.Tensor.squeeze) and is useful for
adding dimensions needed for broadcasting or matrix operations.
```python
from max import tensor
from max.dtype import DType
# Create a 1D tensor
x = tensor.Tensor.constant([1.0, 2.0, 3.0], dtype=DType.float32)
print(x.shape) # (3,)
# Add dimension at the end
y = x.unsqueeze(axis=-1)
print(y.shape) # (3, 1)
# Add dimension at the beginning
z = x.unsqueeze(axis=0)
print(z.shape) # (1, 3)
```
**Parameters:**
axis ([int](https://docs.python.org/3/library/functions.html#int)) – The index at which to insert the new dimension. If negative,
indexes relative to 1 plus the rank of the tensor. For example,
`axis=-1` adds a dimension at the end.
**Returns:**
A tensor with an additional size-1 dimension.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `zeros()` {#max.tensor.Tensor.zeros}
> classmethod zeros(shape, \*, dtype=None, device=None)
Creates a tensor filled with zeros.
Returns a new tensor with the specified shape where all elements are
initialized to zero. The tensor is created with eager execution and
automatic compilation.
```python
from max import tensor
from max.driver import CPU
from max.dtype import DType
# Create a 2x3 tensor of zeros
x = tensor.Tensor.zeros((2, 3), dtype=DType.float32, device=CPU())
# Result: [[0.0, 0.0, 0.0],
# [0.0, 0.0, 0.0]]
# Create a 1D tensor using default dtype and device
y = tensor.Tensor.zeros((5,))
```
**Parameters:**
* shape ([Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) | [Dim](graph/dim.md#max.graph.dim.Dim) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]]) – The shape of the output tensor. Can be a tuple of integers,
a list of integers, or any value that can be converted to a shape.
* dtype ([DType](dtype.md#max.dtype.DType) | None) – The data type for the tensor elements. If not specified,
defaults to `DType.float32` for CPU devices and
`DType.bfloat16` for accelerator devices.
* device ([Device](driver.md#max.driver.Device) | None) – The device where the tensor will be allocated. If not
specified, defaults to an accelerator if available, otherwise CPU.
**Returns:**
A new tensor with the specified shape filled with zeros.
**Return type:**
[Tensor](#max.tensor.Tensor)
### `zeros_like()` {#max.tensor.Tensor.zeros_like}
> classmethod zeros\_like(input)
Creates a tensor of zeros matching a given tensor’s properties.
Returns a new tensor filled with zeros that matches the shape, data type,
and device of the input tensor. This behaves like NumPy’s `zeros_like`
and PyTorch’s `zeros_like`.
```python
from max import tensor
from max.dtype import DType
# Create a reference tensor
ref = tensor.Tensor.ones([3, 4], dtype=DType.float32)
# Create zeros tensor matching the reference tensor
x = tensor.Tensor.zeros_like(ref)
# Result: 3x4 tensor of zeros with dtype float32
```
**Parameters:**
input ([Tensor](#max.tensor.Tensor) | [TensorType](graph/type.md#max.graph.type.TensorType)) – The tensor or tensor type to match. The returned tensor will
have the same shape, dtype, and device as this input.
**Returns:**
A new tensor filled with zeros matching the properties of the
input.
**Return type:**
[Tensor](#max.tensor.Tensor)
## `current_realization_context()` {#max.tensor.current_realization_context}
> max.tensor.current\_realization\_context()
Return the value of the realization context variable for the current context.
If there is no value for the variable in the current context, the method will:
* return the value of the default argument of the method, if provided; or
* return the default value for the context variable, if it was created with one; or
* raise a `LookupError`.
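This lookup order is the standard behavior of Python's `contextvars.ContextVar.get()`, which the sketch below demonstrates with the standard library alone (independent of MAX):

```python
from contextvars import ContextVar

# A context variable created with a default value.
var = ContextVar("demo", default="fallback")

# No value set in the current context: the variable's default is used.
assert var.get() == "fallback"
# An explicit default argument takes precedence over the variable's default.
assert var.get("explicit") == "explicit"

token = var.set("active")
assert var.get() == "active"

# Resetting restores the previous state.
var.reset(token)
assert var.get() == "fallback"
```

A `ContextVar` created without a default raises `LookupError` on `get()` when no value is set and no default argument is passed.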
## `default_device()` {#max.tensor.default_device}
> max.tensor.default\_device(device)
Context manager for setting the default device for tensor creation.
Sets the default device used for tensor creation within the context. All
tensors created inside the context block without an explicit device
parameter will use this device.
```python
from max import tensor
from max.driver import CPU
# Use CPU as default device in this context
with tensor.default_device(CPU()):
x = tensor.Tensor.ones((2, 3)) # Created on CPU
y = tensor.Tensor.zeros((2, 3)) # Also on CPU
```
**Parameters:**
device ([Device](driver.md#max.driver.Device) | [DeviceRef](graph/type.md#max.graph.type.DeviceRef)) – The device to use as the default for tensor creation within
the context.
**Returns:**
A context manager that sets the default device.
## `default_dtype()` {#max.tensor.default_dtype}
> max.tensor.default\_dtype(dtype)
Context manager for setting the default dtype for tensor creation.
Sets the default data type used for tensor creation within the context. All
tensors created inside the context block without an explicit dtype parameter
will use this data type.
```python
from max import tensor
from max.dtype import DType
# Use int32 as default dtype in this context
with tensor.default_dtype(DType.int32):
x = tensor.Tensor.ones((2, 3)) # Created with int32
y = tensor.Tensor.zeros((2, 3)) # Also int32
```
**Parameters:**
dtype ([DType](dtype.md#max.dtype.DType)) – The data type to use as the default for tensor creation within
the context.
**Returns:**
A context manager that sets the default dtype.
## `defaults()` {#max.tensor.defaults}
> max.tensor.defaults(dtype=None, device=None)
Gets the default dtype and device for tensor creation.
Returns a tuple containing the dtype and device to use for tensor creation,
applying defaults when values are not specified. If no dtype is provided,
defaults to `DType.float32` for CPU and `DType.bfloat16` for
accelerators. If no device is provided, defaults to an accelerator if
available, otherwise CPU.
**Parameters:**
* dtype ([DType](dtype.md#max.dtype.DType) | None) – The data type to use. If not specified, a default dtype based
on the device is returned.
* device ([Device](driver.md#max.driver.Device) | None) – The device to use. If not specified, defaults to an available
accelerator or CPU.
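The documented defaulting rules can be mirrored in plain Python (a sketch only; `pick_defaults` is a hypothetical stand-in, not the MAX implementation):

```python
def pick_defaults(dtype=None, device=None, accelerator_available=False):
    """Mirror the documented rules of max.tensor.defaults()."""
    if device is None:
        # Prefer an accelerator when one is available, otherwise CPU.
        device = "accelerator" if accelerator_available else "cpu"
    if dtype is None:
        # float32 on CPU, bfloat16 on accelerators.
        dtype = "bfloat16" if device == "accelerator" else "float32"
    return dtype, device

assert pick_defaults() == ("float32", "cpu")
assert pick_defaults(accelerator_available=True) == ("bfloat16", "accelerator")
```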
## `defaults_like()` {#max.tensor.defaults_like}
> max.tensor.defaults\_like(like)
Context manager setting the default dtype and device for tensor creation.
Sets the default data type and device used for tensor creation within the
context. All tensors created inside the context block without explicit
dtypes or devices will use these parameters.
```python
from max import tensor
from max.driver import CPU
from max.dtype import DType
x = tensor.Tensor.zeros([1], dtype=DType.int32, device=CPU())
# Use x's dtype and device as defaults in this context
with tensor.defaults_like(x):
y = tensor.Tensor.zeros((2, 3)) # int32, cpu
z = tensor.Tensor.zeros((2, 3), dtype=DType.float32) # float32, cpu
```
**Parameters:**
like ([Tensor](#max.tensor.Tensor) | [TensorType](graph/type.md#max.graph.type.TensorType)) – A tensor to use as the default dtype and device for the context.
**Returns:**
A context manager that sets the default dtype and device.
## `realization_context()` {#max.tensor.realization_context}
> max.tensor.realization\_context(ctx)
Sets the current realization context, within a context manager.
New tensors created within this block will use the given realization
context to execute.
See [`RealizationContext`](#max.tensor.RealizationContext).
**Parameters:**
ctx ([RealizationContext](#max.tensor.RealizationContext)) – The realization context to set as the current context.
**Returns:**
A context manager. When the context manager is entered, it will
set ctx as the current realization context. When exited the
current realization context will be reset to its previous value.
---
## torch
## `CustomOpLibrary` {#max.torch.CustomOpLibrary}
> class max.torch.CustomOpLibrary(kernel\_library)
A PyTorch interface to custom operations implemented in Mojo.
This API allows for easy passing of PyTorch data as
`torch.Tensor` values to the corresponding custom op. `CustomOpLibrary`
handles the compilation of the Mojo custom ops and marshalling of data between
PyTorch and the executable Mojo code.
For example, consider a grayscale operation implemented in Mojo:
```mojo title="my_library/grayscale.mojo"
@register("grayscale")
struct Grayscale:
@staticmethod
fn execute[
# The kind of device this is running on: "cpu" or "gpu"
target: StaticString,
](
img_out: OutputTensor[dtype = DType.uint8, rank=2],
img_in: InputTensor[dtype = DType.uint8, rank=3],
ctx: DeviceContextPtr,
) raises:
...
```
You can then use `CustomOpLibrary` to invoke the Mojo operation like so:
```python
import torch
from max.torch import CustomOpLibrary
op_library = CustomOpLibrary("my_library")
grayscale_op = op_library.grayscale
def grayscale(pic: torch.Tensor) -> torch.Tensor:
result = pic.new_empty(pic.shape[:-1])
grayscale_op(result, pic)
return result
img = (torch.rand(64, 64, 3) * 255).to(torch.uint8)
result = grayscale(img)
```
The custom operation produced by `op_library.grayscale` will have the
same interface as the backing Mojo operation. Each `InputTensor` or
`OutputTensor` argument corresponds to a
[`torch.Tensor`](https://docs.pytorch.org/docs/stable/tensors.html#tensor-class-reference)
value in Python. Each argument corresponding to an `OutputTensor` in the
Mojo operation will be modified in-place.
For more information, see the [custom ops for PyTorch](/max/tutorials/custom-kernels-pytorch) tutorial.
**Parameters:**
kernel\_library (Path | [KernelLibrary](graph/KernelLibrary.md#max.graph.KernelLibrary)) – The path to a `.mojo` file or a `.mojopkg` with
your custom op kernels, or the corresponding library object.
## `graph_op()` {#max.torch.graph_op}
> max.torch.graph\_op(fn=None, name=None, kernel\_library=None, input\_types=None, output\_types=None, num\_outputs=None)
A decorator to create PyTorch custom operations using MAX graph operations.
This decorator allows you to define larger graphs using [MAX graph
ops](/max/api/python/graph/ops) or the MAX `nn` modules and
call them with PyTorch tensors, or integrate them into PyTorch modules.
These custom ops can be called eagerly, and support compilation with
`torch.compile` and the Inductor backend.
The resulting custom operation uses destination-passing style, where output
tensors are passed as the first arguments and modified in-place. This
allows PyTorch to manage the memory and streams of the output tensors.
Tensors internal to the computation are managed via MAX’s graph compiler
and memory planning.
The default behavior is to JIT-compile for the specific input and output
shapes needed. If you are passing variable-sized inputs, for instance a
batch size or sequence length which may take on many different values
between calls, you should specify this dimension as a symbolic dimension
through `input_types` and `output_types`. Otherwise you will
end up compiling specialized graphs for each possible variation of
inputs, which may use a lot of memory.
If neither `output_types` nor `num_outputs` is specified, the operation
defaults to a single output.
For example to create a functional-style PyTorch op backed by MAX:
```python title="grayscale.py"
import torch
import numpy as np
import max.torch
from max.dtype import DType
from max.graph import ops
@max.torch.graph_op
def max_grayscale(pic: max.graph.TensorValue):
scaled = pic.cast(DType.float32) * np.array([0.21, 0.71, 0.07])
grayscaled = ops.sum(scaled, axis=-1).cast(pic.dtype)
# max reductions don't remove the dimension, need to squeeze
return ops.squeeze(grayscaled, axis=-1)
@torch.compile
def grayscale(pic: torch.Tensor):
output = pic.new_empty(pic.shape[:-1]) # Remove color channel dimension
max_grayscale(output, pic) # Call as destination-passing style
return output
device = "cuda" if torch.cuda.is_available() else "cpu"
img = (torch.rand(64, 64, 3, device=device) * 255).to(torch.uint8)
result = grayscale(img)
print(f"Input shape: {img.shape}")
print(f"Output shape: {result.shape}")
print("Grayscale conversion completed successfully!")
```
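The weighted-sum math in the graph above can be checked with plain NumPy, using the same `[0.21, 0.71, 0.07]` channel weights as the example (a numerical sketch only, not a MAX call):

```python
import numpy as np

# A small random RGB image, matching the example's uint8 input.
img = (np.random.rand(8, 8, 3) * 255).astype(np.uint8)

weights = np.array([0.21, 0.71, 0.07])
# Cast up, apply per-channel weights, reduce the color axis, cast back.
gray = (img.astype(np.float32) * weights).sum(axis=-1).astype(np.uint8)

assert gray.shape == (8, 8)   # color channel removed
assert gray.dtype == np.uint8
```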
**Parameters:**
* fn ([Callable](graph/ops.md#max.graph.ops.Callable)\[\[...], [Iterable](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[Value](graph/Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)]] | [Value](graph/Value.md#max.graph.Value)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any)] | None] | None) – The function to decorate. If None, returns a decorator.
* name ([str](https://docs.python.org/3/library/stdtypes.html#str) | None) – Optional name for the custom operation. Defaults to the function name.
* kernel\_library (Path | [KernelLibrary](graph/KernelLibrary.md#max.graph.KernelLibrary) | None) – Optional kernel library to use for compilation. Useful
for creating graphs with custom Mojo ops.
* input\_types ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[TensorType](graph/type.md#max.graph.type.TensorType)] | None) – Optional sequence of input tensor types for compilation.
If None, types are inferred from runtime arguments.
* output\_types ([Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[TensorType](graph/type.md#max.graph.type.TensorType)] | None) – Optional sequence of output tensor types for compilation.
If None, types are inferred from runtime arguments.
* num\_outputs ([int](https://docs.python.org/3/library/functions.html#int) | None) – The number of outputs of the graph. We need to know this ahead
of time to register with PyTorch before we’ve compiled the final kernels.
**Returns:**
A PyTorch custom operation that can be called with torch.Tensor arguments.
---
## What's new
Here's everything you should know about what's changed in each release.
## Nightly: v26.2
This version is still a work in progress.
See how to [install the nightly
release](/max/packages#install).
### MAX models {#26-2-models}
* Add support for Qwen/Qwen3-30B-A3B-Instruct-2507, a mixture-of-experts (MoE) model.
* Add multi-GPU tensor parallelism support for Qwen3 and Qwen3-MoE models.
* Remove legacy Gemma 3 multimodal implementation and the
`MODULAR_MAX_DISABLE_GEMMA3_VISION` environment variable.
* Implement multi-GPU support (tensor parallelism) for GPT-OSS.
### MAX framework {#26-2-max}
#### Inference server {#26-2-max-serve}
* Enabled overlap scheduling by default for select model architectures such as
`LlamaForCausalLM_Legacy`. This optimization reduces CPU overhead by
overlapping Python host code with GPU kernel execution. It is currently
incompatible with some features, such as structured outputs and CPU models,
and is still highly experimental. You can forcibly disable it with
`--no-enable-overlap-scheduler --force`.
#### Python API {#26-2-max-python}
* Keep a global MLIR context active and drop per-graph context plumbing so
algebraic dims and graph/custom op construction work without an explicit
context manager. Threadpool-backed MAX paths now scope worker-thread MLIR
usage to the default context automatically.
### MAX kernels {#26-2-max-kernels}
### Mojo language {#26-2-mojo}
For all the updates to the Mojo language, standard library, and tools,
including all GPU programming and `Layout`/`LayoutTensor` changes, see the [Mojo
changelog](/mojo/changelog).
## v26.1 (2026-01-29)
### Highlights {#26-1-highlights}
The eager-style [`Tensor`](/max/api/python/tensor#max.tensor.Tensor) and
[`Module`](/max/api/python/nn/module#max.nn.module.Module) APIs are
now the primary API for model development, providing a PyTorch-like development
experience:
```python
from max import functional as F
from max.tensor import Tensor
from max.dtype import DType
x = Tensor.constant([1.0, -2.0, 3.0, -4.0, 5.0], dtype=DType.float16)
y = F.relu(x)
print(y)
# Tensor([1 0 3 0 5], dtype=DType.float16, device=Device(type=gpu,id=0))
```
If you want explicit control over the graph structure, you can
still build models with the [`Graph`](/max/api/python/graph/Graph) APIs.
For more details, see the [model developer guide](/max/develop/).
### Documentation {#26-1-docs}
* The fully refactored [MAX LLM book](https://llm.modular.com/) is now designed
so the code you write in each exercise incrementally builds upon the last one,
until you've built an executable GPT-2 model with the MAX Python API.
* New model developer guide introduces [eager-style
programming](/max/develop/), [tensor APIs](/max/develop/tensors), and [data
types](/max/develop/dtypes). Much more is coming soon.
* New guide to [profile MAX on GPUs with `nsys`](/max/gpu-system-profiling).
* Extended [documentation for
`kbench`](https://github.com/modular/modular/tree/main/max/kernels/benchmarks/autotune#kbench-a-benchmarking-toolkit-for-mojo-kernels),
a Python tool to benchmark, autotune, and analyze MAX kernel performance.
### MAX models {#26-1-models}
* [Gemma3](https://builds.modular.com/models/gemma-3-it/27B) now supports
vision input (multimodal) in the 12B and 27B variants, including support for
local file paths and structured output. Learn more in the [image to text
guide](/max/inference/image-to-text).
* Added `Qwen/Qwen3-VL-4B-Instruct` and `Qwen/Qwen3-VL-2B-Instruct`
model architectures.
* Removed Llama 3.2 Vision (`Llama-3.2-11B-Vision-Instruct`) architecture support.
Use other vision models such as Pixtral, InternVL, Qwen2.5-VL, and Gemma3.
### MAX framework {#26-1-max}
* All Python wheels are now hosted at `https://whl.modular.com/nightly/simple/`.
If using `uv`, change `--index-url` to `--index`, and if using `pip`, change to
`--extra-index-url`. For precise commands, see the
[install guide](/max/packages#install).
#### Inference server {#26-1-max-serve}
* Improved scheduling to achieve higher KVCache utilization and batch sizes. By
default, MAX now schedules a context encoding (CE) request only if KVCache
memory is less than 95% full *after* allocating blocks for that request or if
no active requests exist. You can adjust this watermark value (`0.95`) with
[`--kvcache-ce-watermark`](/max/cli/serve#--kvcache-ce-watermark-kvcache_ce_watermark).
Beware that increasing it causes more preemptions.
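The scheduling rule above can be sketched in plain Python (`should_schedule_ce` is a hypothetical illustration, not the actual scheduler code):

```python
def should_schedule_ce(used_blocks, request_blocks, total_blocks,
                       active_requests, watermark=0.95):
    # Always admit a CE request when nothing else is running.
    if active_requests == 0:
        return True
    # Otherwise admit only if the KVCache stays under the watermark
    # *after* allocating blocks for this request.
    return (used_blocks + request_blocks) / total_blocks < watermark

assert should_schedule_ce(50, 10, 100, active_requests=3)       # 60% < 95%
assert not should_schedule_ce(90, 10, 100, active_requests=3)   # would hit 100%
assert should_schedule_ce(99, 10, 100, active_requests=0)       # idle server
```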
* When running models with data-parallelism (DP), the semantics of max batch size
has changed. For example, when specifying `--data-parallel-degree 8` and
`--max-batch-size 32` it previously meant that each data-parallel replica could
have at most 4 requests, for an aggregate max batch size of 32. The CLI flag
now specifies the max batch size per replica, so the aggregate max batch size
for the values above is 8 × 32 = 256 requests. This aligns with vLLM and
other inference engines.
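The new batch-size arithmetic can be stated as a one-liner (illustrative only):

```python
def aggregate_max_batch_size(data_parallel_degree, max_batch_size_per_replica):
    # --max-batch-size is now per replica, so the aggregate cap is the product.
    return data_parallel_degree * max_batch_size_per_replica

# e.g. --data-parallel-degree 8 --max-batch-size 32
assert aggregate_max_batch_size(8, 32) == 256
```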
* `--max-ce-batch-size` is now deprecated. The cap on batch size is now uniform
between context encoding and token generation phases of text generation. Use
`--max-batch-size` instead.
* The API server now returns chunked tokens from the model worker, reducing overhead
and significantly improving throughput for small models and decode-heavy
workloads.
* Server stats collection (`collect_server_stats`) is now enabled by default for
serving benchmarks.
#### `max` CLI {#26-1-max-cli}
* The `max generate` command now applies the model's chat template internally
when using `--prompt`. This more closely aligns with how users typically prompt
a model for testing and ensures special tokens are properly filtered from
output.
* Added tracing flags to `max benchmark` for `nsys` profiling:
* `--trace`: Enable tracing of the benchmark run (currently NVIDIA GPUs only)
* `--trace-file`: Path to save the trace file
* `--trace-session`: Optional session name for tracing
Requires the server to be run under `nsys launch`. Using
`--gpu-profiling detailed` is recommended.
#### Python API {#26-1-max-python}
* The eager-style [`Tensor`](/max/api/python/tensor#max.tensor.Tensor) APIs are
now the primary API for model development, providing a PyTorch-like development
experience.
We moved the eager-style tensor APIs out of `experimental` and
reorganized the `max.nn` module to make the eager module
system the primary API (`nn.module_v3` is now `nn.module`).
The previous [`max.nn`](/max/api/python/nn/) components are still available
for backward compatibility in [`max.nn.legacy`](/max/api/python/nn/legacy/).
* Renamed `max.driver.Tensor` to
[`max.driver.Buffer`](/max/api/python/driver#max.driver.Buffer) to clarify that
it represents a low-level memory buffer, not a tensor. The
[`max.tensor.Tensor`](/max/api/python/tensor#max.tensor.Tensor) class remains
the primary tensor type.
* Added `forward()` method to
[`Module`](/max/api/python/nn/module#max.nn.module.Module) to compute the
output—it behaves the same as invoking the object as a callable (the
`__call__()` method).
* `accelerator_count()` now returns a non-zero value when called on an Apple
silicon system. This means you can use this code:
```python
device = CPU() if accelerator_count() == 0 else Accelerator()
```
And it defaults to using the available Apple silicon GPU. As a consequence,
MAX graphs should in most cases be dispatched to run on Apple silicon GPUs.
Note that most MAX models do not yet work on Apple silicon GPUs due to
missing hardware-specific kernel pathways and other support, but this is an
important step towards enabling MAX more broadly on Apple silicon GPUs.
* Added `max.nn.module.rope` containing rotary embedding implementations,
[`RotaryEmbedding`](/max/api/python/nn/rope/RotaryEmbedding) and
[`TransposedRotaryEmbedding`](/max/api/python/nn/rope/TransposedRotaryEmbedding).
* Added
[`ArchConfig`](/max/api/python/pipelines/interfaces#max.pipelines.lib.interfaces.ArchConfig)
and `ArchConfigWithKVCache`. Going forward, models that register with the MAX
architecture registry must define a config that implements this protocol.
* Added `ops.complex.mul` for multiplying complex-valued tensors.
* Added `calculate_virtual_device_count()`, `calculate_virtual_device_count_from_cli()`,
`load_max_buffer()` to [`max.driver`](/max/api/python/driver/).
* Added [`TokenBuffer`](/max/api/python/interfaces#max.interfaces.TokenBuffer)
for token management.
* Renamed `prefill_chunk_size` to `max_batch_input_tokens`
and `max_batch_context_length` to `max_batch_total_tokens`
in [`PipelineConfig`](/max/api/python/pipelines/config/#max.pipelines.lib.config.PipelineConfig)
and `TTSConfig` classes to better reflect their purpose in batch memory
management.
The corresponding CLI flags have also been renamed:
`--prefill-chunk-size` is now `--max-batch-input-tokens` and
`--max-batch-context-length` is now `--max-batch-total-tokens`.
* Fixed `max.driver.Buffer.to(stream)` to not copy (it returns a reference to
the same buffer) when the stream is on the same device, even for GPU-pinned
host memory.
* Removed deprecated `max.nn` convolution classes: `Conv2dV1`, `Conv1DV1`,
`Conv3DV1`. Use `Conv2d`, `Conv1D`, `Conv3D` instead.
* Removed deprecated `max.nn` layer classes: `LinearV1`, `QLinearV1`,
`GPTQLinearV1`, `MLPV1`, `EmbeddingV1`, `LayerNormV1`, `RMSNormV1`. Use
`Linear`, `GPTQLinear`, `MLP`, `Embedding`, `LayerNorm`, `RMSNorm` instead.
* Removed `max.engine.MojoValue`
* Removed the deprecated `custom_ops_path` parameter from
[`InferenceSession.load()`](/max/api/python/engine#max.engine.InferenceSession.load).
Instead use the `custom_extensions` parameter.
* Added `graph.ops.shard_and_stack()`
* Removed unused `graph.weights.PytorchWeights`
### MAX kernels {#26-1-max-kernels}
* Improved Hopper matmul performance for skinny M shapes. In particular, when
M is between 2 and 64, specific shapes see a performance boost of 10–40%.
* Added a swapAB optimization to Hopper matmul, which computes B × A and does
a transposed write to C. This helps when you need more granularity in the M
dimension.
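The swapAB trick relies on the identity A·B = (Bᵀ·Aᵀ)ᵀ, which is easy to verify with NumPy (a numerical check of the math only, not the kernel):

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((4, 16))   # skinny M dimension
b = rng.standard_normal((16, 8))

# Multiplying the swapped, transposed operands and doing a transposed
# write yields the same result as the direct matmul.
swapped = (b.T @ a.T).T
assert np.allclose(swapped, a @ b)
```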
* Refined `create_stream` API: all streams are now non-blocking (`blocking`
argument has been removed). Explicitly use `DeviceEvent` and `synchronize()`
wherever necessary.
### Mojo language {#26-1-mojo}
For all the updates to the Mojo language, standard library, and tools,
including all GPU programming and `Layout`/`LayoutTensor` changes, see the [Mojo
changelog](/mojo/changelog)
## v25.7 (2025-11-20)
### Highlights {#25-7-highlights}
* The MAX Python API is now [fully open-sourced on
GitHub](https://github.com/modular/modular/tree/main/max/python/max)!
As we expand our [model
repository](https://builds.modular.com/?category=models), we're making
significant progress on these APIs to simplify the effort to build
production-ready GenAI models in Python. Some APIs are still experimental,
but you can [build an LLM with it today](https://llm.modular.com).
### Documentation {#25-7-docs}
* New online book to [build an LLM from scratch with
MAX](https://llm.modular.com), using our **experimental model APIs**. This is a
guided lesson on building GPT-2 with our Python API, explaining each component
of the transformer model along the way. Like the Python APIs, the book is a
work in progress—please [report any issues in
GitHub](https://github.com/modular/max-llm-book/issues).
* All the planned parts of [GPU Puzzles](https://puzzles.modular.com/) are now
complete! Support for Apple silicon GPUs is also making [steady
progress](https://puzzles.modular.com/howto.html#gpu-support-matrix).
* Tutorials on docs.modular.com are now integrated into the
[Guides](/max/intro) section, indicated with a book icon in the left
navigation.
* The [`max` CLI docs](/max/cli/) are now generated from [the CLI
source](https://github.com/modular/modular/blob/main/max/python/max/entrypoints/pipelines.py).
### MAX models {#25-7-models}
* Gemma3 now supports logprobs.
### MAX framework {#25-7-max}
* Added support for bfloat16 models running on GPUs with ARM-based CPU hosts,
such as Grace Hopper (GH200) and Grace Blackwell (GB200).
* Updated minimum NVIDIA GPU driver requirement to 580.
#### `max` CLI {#25-7-max-cli}
* [`max benchmark`](/max/cli/benchmark) can now run LoRA benchmarking for
supported models and target modules.
* `max benchmark --collect-gpu-stats` can now collect AMD
GPU statistics.
* `max serve --do-penalties` was renamed to `--enable-penalties` and enabled by
default. To disable penalties, you can specify
[`--no-enable-penalties`](/max/cli/serve#--enable-penalties---no-enable-penalties).
#### Python API {#25-7-max-python}
* Added support for Python 3.14.
* Removed support for Python 3.9.
* All MAX Python API modules are now **open-sourced**. In addition to those
previously released, we've added `driver`, `dtype`, `engine`, `experimental`,
`interfaces`, `kv_cache`, `mlir`, `nn`, `profiler`, `support`, `torch`, and
more [in our GitHub
repo](https://github.com/modular/modular/tree/main/max/python/max).
* Added the [`max.profiler`](/max/api/python/profiler) module with the
[`Tracer`](/max/api/python/profiler#max.profiler.Tracer) class to create and
manage profiling spans based on runtime conditions, and the `@traced()`
decorator to profile a whole function.
* Added [`max.diagnostics.gpu`](/max/api/python/diagnostics/gpu) APIs to expose
common GPU statistics as might be reported by `nvidia-smi` or `rocm-smi`.
* Added the [`max.kv_cache`](/max/api/python/kv_cache/) package, which provides
APIs to manage key-value caches used in transformer models. Not to be confused
with the existing [`max.nn.kv_cache`](/max/api/python/nn/kv_cache/) package that
includes kernels for KV caching.
* Removed the `KVCacheManager` class and combined it with the single
[`PagedKVCacheManager`](/max/api/python/kv_cache/paged_cache/cache_manager#max.kv_cache.paged_cache.cache_manager.PagedKVCacheManager)
implementation. During the merge, `prefetch()` was renamed to `maybe_reserve()`.
* Added
[`NullKVCacheManager`](/max/api/python/kv_cache/null_cache_manager#max.kv_cache.NullKVCacheManager)
for compile-only mode, which avoids GPU memory allocation when compiling models
without a physical GPU present.
* Added
[`ResetPrefixCacheBackend`](/max/api/python/kv_cache/paged_cache/tp_cache_manager#max.kv_cache.paged_cache.ResetPrefixCacheBackend)
and
[`ResetPrefixCacheFrontend`](/max/api/python/kv_cache/paged_cache/tp_cache_manager#max.kv_cache.paged_cache.ResetPrefixCacheFrontend)
classes for coordinating prefix cache resets between frontend and backend
components.
* Added more APIs for text-to-speech (TTS) models such as
[`AudioGenerationInputs`](/max/api/python/interfaces#max.interfaces.AudioGenerationInputs)
and
[`AudioGenerationOutput`](/max/api/python/interfaces#max.interfaces.AudioGenerationOutput)
* Changed
[`LoRAConfig.max_num_loras`](/max/api/python/pipelines/lora_config#max.pipelines.lib.lora_config.LoRAConfig.max_num_loras)
default to `1` (was `100`).
* New [`RequestID`](/max/api/python/interfaces/#max.interfaces.RequestID) class
replaces previous type alias to provide better type safety and consistency
across the API.
* Removed `InputContext` and replaced it with the modality-output specific
[`TextGenerationContext`](/max/api/python/interfaces/#max.interfaces.TextGenerationContext)
and
[`EmbeddingsContext`](/max/api/python/interfaces/#max.interfaces.EmbeddingsContext).
* Added
[`ImageMetadata`](/max/api/python/interfaces/#max.interfaces.ImageMetadata) and
[`VLMTextGenerationContext`](/max/api/python/interfaces/#max.interfaces.VLMTextGenerationContext).
* Added [`max.nn.comm`](/max/api/python/nn/comm/) with `Allreduce` and
`Signals` for peer-to-peer communication in allreduce.
* [`ops.gather()`](/max/api/python/graph/ops#max.graph.ops.gather) no longer
has a default `axis`; it must be specified explicitly (better matching PyTorch
and NumPy).
* [`Graph.add_subgraph()`](/max/api/python/graph/Graph#max.graph.Graph.add_subgraph)
has been updated to take a `devices` argument. This allows subgraphs to take
advantage of device-aware work scheduling.
#### Mojo API {#25-7-max-mojo}
* Renamed the `tensor_internal` package to `tensor` and removed the
previous `tensor` stub—the API behaves the same but the [Mojo `tensor`
docs](/mojo/kernels/extensibility/tensor/) moved.
### Mojo language {#25-7-mojo}
For all the updates to the Mojo language, standard library, and tools,
including all GPU programming and `Layout`/`LayoutTensor` changes, see the [Mojo
changelog](/mojo/changelog).
## v25.6.1 (2025-10-10)
Fixes a latency regression due to a top-k algorithm change and a couple
other benchmarking bugs.
## v25.6 (2025-09-22)
* [Highlights](#25-6-highlights)
* [Documentation](#25-6-docs)
* [MAX models](#25-6-models)
* [MAX framework](#25-6-max)
* [Inference server](#25-6-max-serve)
* [`max` CLI](#25-6-max-cli)
* [Python API](#25-6-max-python)
* [MAX kernels](#25-6-kernels)
* [Mojo language](#25-6-mojo)
### Highlights {#25-6-highlights}
* MAX delivers **state-of-the-art performance on NVIDIA Blackwell** (B200)!
We've been describing our Blackwell bring-up over a series of blog posts, and
we recently published [Part 4: Breaking
SOTA](https://www.modular.com/blog/matrix-multiplication-on-blackwell-part-4---breaking-sota),
in which we share our latest matmul benchmarks compared to NVIDIA's cuBLAS
library.
* MAX provides **industry-leading performance on AMD MI355X**!
In a matter of weeks, we got MAX running on the brand-new MI355X system and
have already produced early benchmarks that go head-to-head with Blackwell.
If you have access to an MI355X, you can try it yourself today by following
our [quickstart guide](/max/get-started).
* Benchmarking endpoints is easier than ever with the new [`max
benchmark`](/max/cli/benchmark) command, which accepts YAML
configuration files so you can easily share and reproduce your benchmarks.
### Documentation {#25-6-docs}
* Our new [quickstart guide](/max/get-started) lets you pick the model
architecture and size you want, and then shows you how to deploy it and run our
open-source benchmarking script, all from the `max` CLI.
* We updated and simplified the [benchmarking
tutorial](/max/deploy/benchmark) to use the new `max benchmark`
command.
### MAX models {#25-6-models}
* Added the
[gpt-oss](https://github.com/modular/modular/tree/modular/v25.6.0/max/pipelines/architectures/gpt_oss)
model architecture (GPU, bfloat16).
[Try GPT-OSS now](https://builds.modular.com/models/gpt-oss-20b-BF16/20B).
### MAX framework {#25-6-max}
* Added device-aware work scheduling for AsyncRT: work items can now specify a
`deviceHint` to route execution to specific worker threads based on device
affinity, improving multi-device performance.
* Improved code quality by enabling a large set of Ruff lints, including
[flake8-annotations (ANN)](https://docs.astral.sh/ruff/rules/#flake8-annotations-ann)
which now enforces Python type annotations for new contributions.
#### Inference server {#25-6-max-serve}
* Added support for data parallelism in Llama models. To enable this feature,
use the `--data-parallel-degree` option:
```sh
max serve --model $MODEL_ID --data-parallel-degree 2 --devices gpu:0,1
```
* Metrics for each context encoding and token generation batch are now logged
to the console periodically. You can override the default frequency (3 seconds)
of these logs by setting the `MAX_SERVE_SCHEDULER_STATS_LOG_INTERVAL_S`
environment variable. For example, setting
`MAX_SERVE_SCHEDULER_STATS_LOG_INTERVAL_S=0` logs metrics for every batch.
* Improved error messages when pulling a model that requires more RAM than
what's available or when there won't be enough RAM left for the KV cache.
#### `max` CLI {#25-6-max-cli}
* Added the `max benchmark` subcommand that runs a suite of benchmarks and
collects performance metrics on a model server. This command provides
convenient packaging/installation for our open-source
[`benchmark_serving.py`](https://github.com/modular/modular/tree/main/benchmark#benchmark-max)
script and accepts all the same options.
* Added `--chat-template` to the CLI for passing a custom chat template
defined in a Jinja2 template file.
* Renamed the `--allow-safetensors-weights-float32-to-bfloat16-cast` flag to
`--allow-safetensors-weights-fp32-bf16-bidirectional-cast`, which supports
automatic bidirectional dtype casts when needed.
* The `max generate` command now supports `--top-k`, `--temperature`, and
`--seed` flags.
* Changed `--num-warmups` behavior. Previously, it ran the model on the prompt
`N` times, generating until reaching a stop condition each time. Now it runs
the model for `N` steps, generating `N` new tokens as a warmup.
* Added the `--model` option as a preferred alternative to `--model-path`. They
behave the same.
* Deprecated `--pad-to-multiple-of`.
* Removed the previously deprecated `--model-name`. Use `--served-model-name`
instead.
#### Python API {#25-6-max-python}
* Removed the previously deprecated `KVCacheStrategy.CONTINUOUS` and all
associated classes (including `ContinuousBatchingKVCacheManager`).
* Added [`ops.fence`](/max/api/python/graph/ops#max.graph.ops.fence), a pure
identity operation that prevents the async runtime from reordering operations
across it. This operation is essential for implementing cross-device
synchronization.
* Removed `PipelineConfig.max_new_tokens`. Use
[`SamplingParams.max_new_tokens`](/max/api/python/pipelines#max.pipelines.SamplingParams)
instead.
* Added
[`logits_processor`](/max/api/python/interfaces/#max.interfaces.SamplingParams.logits_processors)
to
[`SamplingParams`](/max/api/python/interfaces/#max.interfaces.SamplingParams)
for updating logits in-place during each step of token generation.
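A logits processor is just a callable that mutates the logits array in place before sampling. As an illustration of the idea only (the exact signature MAX expects is defined in the `SamplingParams` docs; the helper below is hypothetical), a NumPy sketch of a processor that bans one token:

```python
import numpy as np

def ban_token(token_id):
    """Return a processor that masks out one token id in place (hypothetical helper)."""
    def processor(logits):
        # Setting a logit to -inf gives that token zero sampling probability.
        logits[token_id] = -np.inf
    return processor

logits = np.array([1.0, 2.0, 3.0])
ban_token(2)(logits)                 # mutate in place, as logits processors do
probs = np.exp(logits - logits.max())
probs /= probs.sum()
# token 2 now has probability 0
```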
* Added `generate()` to
[`TextGenerationPipeline`](/max/api/python/pipelines/pipeline#max.pipelines.lib.pipeline.TextGenerationPipeline)
and
[`SpeculativeDecodingPipeline`](/max/api/python/pipelines#max.pipelines.SpeculativeDecodingPipeline),
a convenience method for getting text generations. `generate_async()` is
available for getting streamed outputs.
* Renamed the `target_num_new_tokens` configuration parameter to
[`prefill_chunk_size`](/max/api/python/pipelines/config/#max.pipelines.lib.config.PipelineConfig.prefill_chunk_size)
in
[`PipelineConfig`](/max/api/python/pipelines/config/#max.pipelines.lib.config.PipelineConfig)
and `TTSConfig` classes to better reflect its role in chunked prefill
operations.
* Fixed [`ops.range`](/max/api/python/graph/ops#max.graph.ops.range) to respect
the `dtype` parameter when using [`Dim`](/max/api/python/graph/dim) objects as
inputs. Previously, the dtype was ignored and defaulted to `int64`.
* Made the `devices` argument in
[`InferenceSession()`](/max/api/python/engine#max.engine.InferenceSession)
required. To maintain the previous default behavior, use
`InferenceSession(devices=[CPU()])`.
* Added an optional `logging` argument to
[`InferenceSession()`](/max/api/python/engine#max.engine.InferenceSession).
When set to `"op"`, this option enables operation launch output to stderr.
* Added [`max.nn.lora`](/max/api/python/nn/lora), providing
Low-Rank Adaptation (LoRA) support for parameter-efficient fine-tuning of
neural network models.
* Added [`max.nn.moe`](/max/api/python/nn/moe), implementing
Mixture of Experts (MoE) layers for scalable model architectures.
* Added [`max.nn.sampling`](/max/api/python/nn/sampling),
containing advanced sampling methods including MinP and rejection sampling
techniques.
* Added [`max.nn.hooks`](/max/api/python/nn/hooks), providing
debugging and inspection hooks for neural network layers.
* Added attention submodules
[`max.nn.attention.mask_config`](/max/api/python/nn/attention/mask_config),
[`max.nn.attention.multihead_attention`](/max/api/python/nn/attention/multihead_attention),
and
[`max.nn.attention.multi_latent_attention`](/max/api/python/nn/attention/multi_latent_attention)
for comprehensive attention mechanism configuration and implementation.
* Moved some Mojo-related functionality to a new top-level `mojo` Python
namespace. Specifically, `max.mojo` (previously used for Mojo-Python interop),
some of `max.support`, and `max.entrypoints.mojo` now live under the `mojo`
namespace and are provided in the new [`mojo`
package](/mojo/manual/install#whats-included).
### MAX kernels {#25-6-kernels}
* Added a leaky ReLU activation function kernel.
* Added a specialized [RMS norm](/mojo/kernels/nn/normalization/rms_norm/)
function kernel for the common case of `cols=128`, `bfloat16`.
### Mojo language {#25-6-mojo}
For all the updates to the Mojo language, standard library, and tools,
including all GPU programming changes, see the [Mojo
changelog](/mojo/changelog).
## v25.5 (2025-08-05)
* [Highlights](#25-5-highlights)
* [Documentation](#25-5-docs)
* [MAX models](#25-5-models)
* [MAX framework](#25-5-max)
* [Inference server](#25-5-max-serve)
* [`max` CLI](#25-5-max-cli)
* [Python API](#25-5-max-python)
* [Mojo language](#25-5-mojo)
### Highlights {#25-5-highlights}
* **OpenAI-compatible batch API**: The [`/v1/batches`
API](/max/api/serve#operation/createBatch) is now available with
[Mammoth](/mammoth/).
We recently announced a [partnership with SF
Compute](https://www.modular.com/blog/sf-compute) to make this API available
through their dynamic GPU pricing marketplace. Their Large Scale Inference
Batch API looks different from the `/v1/batches` API in Mammoth because it's
a superset.
* **New `mojo` Conda package**: For Mojo-specific projects that run on CPUs and
GPUs, you can now install the bare essentials with the `mojo` Conda package
that's less than 900 MB on disk. For example, this now works:
```sh
pixi add mojo
```
The `mojo` Python package is not available for pip/uv yet.
For a complete model-development and serving toolkit, you should still install
the `modular` package (which includes `mojo` as a dependency).
* **Open-source graph APIs**: We've added the `max.graph` Python APIs to our
[GitHub
repo](https://github.com/modular/modular/tree/modular/v25.5.0/max/graph). We've
made great strides in recent months to simplify these APIs that help you build
high-performance models you can [serve with
MAX](/max/develop/serve-custom-model-architectures).
### Documentation {#25-5-docs}
* New [Serve custom model architectures
tutorial](/max/develop/serve-custom-model-architectures), with [example code
on
GitHub](https://github.com/modular/modular/tree/main/max/examples/custom-models).
* New guide for [using LoRA adapters with MAX](/max/serve/lora-adapters).
* Updated the [Deploy Llama 3 on GPU
tutorial](/max/tutorials/max-serve-local-to-cloud/) with instructions using
AMD MI300X (on Azure).
* Added [Pixi basics](/pixi), which is where we redirect all the now-removed
Magic docs (see our [announcement migrating Magic to
Pixi](https://forum.modular.com/t/migrating-from-magic-to-pixi/1530)).
### MAX models {#25-5-models}
* Added support for
[Idefics3](https://github.com/modular/modular/tree/modular/v25.5.0/max/pipelines/architectures/idefics3)
model.
### MAX framework {#25-5-max}
* Removed all `torch` package dependencies.
* Reduces the total installation size of `modular` (including
dependencies) from 2.2 GB for CPUs and 6.5 GB for GPUs **down to 1.5 GB**, for
all Python packages. Conda packages pull additional system dependencies so
sizes may vary, but one example brings the size down from 9.8 GB to 2.0 GB.
* `pip install` no longer requires the `--extra-index-url
https://download.pytorch.org/whl/cpu` option (which was to avoid installing
the GPU version of `torch` that has a lot of CUDA dependencies).
* `uv pip install` no longer requires the `--index-strategy unsafe-best-match`
option (which was to avoid package resolution issues with the above
`--extra-index-url` option).
* Removed HuggingFace fallback for model pipelines not natively supported in
MAX (`PipelineEngine.HUGGINGFACE`), because it's almost never used and it
creates significant tech debt.
#### Inference server {#25-5-max-serve}
* Added the [`/health` endpoint](/max/api/serve/#operation/health) for service
readiness checks, used by tools like lm-eval to determine when the service is
ready to accept requests.
* [Prefix caching](/max/serve/prefix-caching) now uses an accelerated Mojo
token hashing operation. Previously we used the `hash()` function from the
Python stdlib, which resulted in noticeable CPU overhead and reduced GPU
utilization.
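Conceptually, hash-based prefix caching keys cached KV blocks by a rolling hash over fixed-size token chunks, so a cache lookup only needs cheap hash comparisons. A simplified Python sketch of the idea (the block size and data structures here are illustrative, not the actual MAX implementation):

```python
BLOCK = 4  # illustrative block size

def block_hashes(tokens):
    """Rolling hash per full block: each hash covers the entire prefix so far."""
    hashes, h = [], 0
    for i in range(0, len(tokens) - len(tokens) % BLOCK, BLOCK):
        h = hash((h, tuple(tokens[i:i + BLOCK])))
        hashes.append(h)
    return hashes

def cached_prefix_len(tokens, cache):
    """Count how many leading tokens are covered by already-cached blocks."""
    n = 0
    for h in block_hashes(tokens):
        if h not in cache:
            break
        n += 1
    return n * BLOCK

cache = set(block_hashes([1, 2, 3, 4, 5, 6, 7, 8]))  # pretend these were served
assert cached_prefix_len([1, 2, 3, 4, 9, 9, 9, 9], cache) == 4
```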
* Re-implemented the OpenAI API's `logprobs` and `echo` request
parameters to eliminate an expensive device transfer.
The `--enable-echo` flag, which previously incurred a significant performance
penalty, is now 9-12x faster.
* Added support for `file://` URIs in image inputs for multimodal models. Local
file access is controlled via the `MAX_SERVE_ALLOWED_IMAGE_ROOTS` environment
variable, which specifies a list of allowed root directories. Files are read
asynchronously using aiofiles for better performance under high load.
* Improved [function calling](/max/serve/function-calling) (tool use) to more
reliably extract JSON tool calling responses for Llama models in an
OpenAI-compatible format.
* Switched from XGrammar to
[llguidance](https://github.com/guidance-ai/llguidance) for generating
structured output (constrained decoding).
#### `max` CLI {#25-5-max-cli}
* Added `--vision-config-overrides` CLI option to override
vision model configuration parameters. For example, to decrease InternVL's
maximum dynamic patches from 12 to 6:
```bash
max serve --model-path OpenGVLab/InternVL3-38B-Instruct \
--vision-config-overrides '{"max_dynamic_patch": 6}'
```
* Removed the `--ignore-eos` CLI argument. The full set of OpenAI chat and
completion sampling parameters is now supported in HTTP requests, so this
parameter can simply be set via the HTTP payload.
#### Python API {#25-5-max-python}
* Added the [`max.interfaces`](/max/api/python/interfaces) module. This module
is intended to be a relatively import-free home for all shared interfaces
across the MAX stack, and we'll gradually move common interfaces into it. So
far, we've moved the following from `max.pipelines.core`:
* Moved `TextGenerationStatus`, `TextResponse`, `TextGenerationResponse`,
`InputContext`, and `PipelineTask` into `max.interfaces`.
* Moved all `TokenGeneratorRequest`-prefixed objects into `max.interfaces`
and renamed with the `TextGenerationRequest` prefix.
* Moved `TextGenerationStatus` to
[`GenerationStatus`](/max/api/python/interfaces/#max.interfaces.GenerationStatus).
* Moved `TextResponse` and `TextGenerationResponse` to
[`TextGenerationOutput`](/max/api/python/interfaces/#max.interfaces.TextGenerationOutput).
* Moved `EmbeddingsResponse` to
[`EmbeddingsOutput`](/max/api/python/interfaces#max.interfaces.EmbeddingsOutput).
* Added [`ops.scatter_nd`](/max/api/python/graph/ops/#max.graph.ops.scatter_nd)
operation for scattering updates into a tensor at specified indices.
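The semantics follow the common `scatter_nd` definition: each index vector selects a position in a copy of the input, and the corresponding update is written there. A NumPy-based sketch of those semantics (illustrative only, not the MAX implementation):

```python
import numpy as np

def scatter_nd(tensor, indices, updates):
    """Write `updates` into a copy of `tensor` at the given index vectors."""
    out = tensor.copy()
    for idx, upd in zip(indices, updates):
        out[tuple(idx)] = upd
    return out

x = np.zeros(5)
result = scatter_nd(x, indices=np.array([[1], [3]]), updates=np.array([9.0, 7.0]))
# result: [0., 9., 0., 7., 0.]
```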
* Added [`ops.avg_pool2d`](/max/api/python/graph/ops/#max.graph.ops.avg_pool2d)
and [`ops.max_pool2d`](/max/api/python/graph/ops/#max.graph.ops.max_pool2d).
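For reference, 2D max pooling slides a window over the input and keeps the maximum in each window. A minimal NumPy sketch of the non-overlapping case (stride equal to kernel size; illustrative only, not the MAX kernel):

```python
import numpy as np

def max_pool2d(x, k):
    """Non-overlapping k x k max pooling over a 2D array (stride = k)."""
    h, w = x.shape
    # Trim to a multiple of k, split into k x k tiles, reduce each tile.
    return x[:h - h % k, :w - w % k].reshape(h // k, k, w // k, k).max(axis=(1, 3))

x = np.arange(16.0).reshape(4, 4)
out = max_pool2d(x, 2)
# out: [[5., 7.], [13., 15.]]
```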
* Added [`max.torch.graph_op`](/max/api/python/torch#max.torch.graph_op)
interface to make it simple to embed larger MAX computations and models inside
PyTorch. These can use `max.nn` modules internally and may be used within
`torch.nn` modules, allowing the use of MAX subcomponents for access to our
high-performance graph compiler and Mojo kernel library.
```python
import torch
import numpy as np

import max
from max.dtype import DType
from max.graph import ops

@max.torch.graph_op
def max_grayscale(pic: max.graph.TensorValue):
    scaled = pic.cast(DType.float32) * np.array([0.21, 0.71, 0.07])
    grayscaled = ops.sum(scaled, axis=-1).cast(pic.dtype)
    # max reductions don't remove the dimension, need to squeeze
    return ops.squeeze(grayscaled, axis=-1)

@torch.compile
def grayscale(pic: torch.Tensor):
    output = pic.new_empty(pic.shape[:-1])  # Remove color channel dimension
    max_grayscale(output, pic)  # Call as destination-passing style
    return output

device = "cuda" if torch.cuda.is_available() else "cpu"
img = (torch.rand(64, 64, 3, device=device) * 255).to(torch.uint8)
result = grayscale(img)
```
* Moved `AlgebraicDim`, `Dim`, `StaticDim`, and `SymbolicDim` out of `max.type`
and into [`max.graph.dim`](/max/api/python/graph/dim). You can still import
them directly from `max.graph`.
* Moved `Shape` out of `max.type` and into
[`max.graph.shape`](/max/api/python/graph/shape). You can still import it
directly from `max.graph`.
* Removed the ability to pass Python objects into models and have them returned
as Mojo `PythonObject` types in the kernels.
* Removed `RandomWeights`.
* Removed `Model.execute_legacy()`. Instead use the
standard [`execute()`](/max/api/python/engine#max.engine.Model.execute) or
[`__call__()`](/max/api/python/engine#max.engine.Model.__call__) methods.
* Removed TorchScript-related helper functions and APIs, including support for
`.pt` TorchScript files in custom extensions.
### Mojo language {#25-5-mojo}
For all the updates to the Mojo language, standard library, and tools,
including all GPU programming changes, see the [Mojo
changelog](/mojo/changelog).
## v25.4 (2025-06-18)
* [Highlights](#25-4-highlights)
* [Documentation](#25-4-docs)
* [MAX models](#25-4-models)
* [MAX framework](#25-4-max)
* [Inference server](#25-4-max-serve)
* [`max` CLI](#25-4-max-cli)
* [Python API](#25-4-max-python)
* [Mojo API](#25-4-max-mojo)
* [Custom ops](#25-4-custom-ops)
* [GPU programming](#25-4-gpu-programming)
* [Mojo language](#25-4-mojo)
### ✨ Highlights {#25-4-highlights}
* **AMD GPUs are officially supported!**
You can now deploy MAX with acceleration on AMD MI300X and MI325X GPUs, using
the same code and container that works on NVIDIA GPUs. For the first time,
you can build portable, high-performance GenAI deployments that run
on any platform without vendor lock-in or platform-specific optimizations.
For more details, including benchmarks, see our [Modular + AMD blog
post](https://www.modular.com/blog/modular-x-amd-unleashing-ai-performance-on-amd-gpus).
* **Now accepting GPU kernel contributions**
Last month, we open-sourced the code for the CPU and GPU kernels that power
the MAX framework, and now we're accepting contributions! For information
about how to contribute and the sort of kernels most interesting to us,
see the [MAX AI kernels contributing
guide](https://github.com/modular/modular/blob/main/max/kernels/CONTRIBUTING.md).
* **Preview: Mojo interoperability from Python**
This release includes an early version of a new Python-to-Mojo
interoperability API. You can now write just the performance-critical parts
of your code in Mojo and call it from Python just like you're importing another
Python library. Check out our docs to [call Mojo from
Python](/mojo/manual/python/mojo-from-python).
### Documentation {#25-4-docs}
We've redesigned [builds.modular.com](https://builds.modular.com) and
[docs.modular.com](https://docs.modular.com) with a unified top navigation bar
so you can more easily discover all the available docs and code resources.
New docs:
* [GPU Puzzles](https://builds.modular.com/puzzles/introduction.html): Several
new puzzles, including: 1D convolution op, softmax op, attention op,
embedding op, kernel fusion, custom backward pass, GPU functional programming
patterns, and warp fundamentals.
* [Using AI coding assistants guide](/max/coding-assistants): Learn how to use
large language models (LLMs) and coding assistants (such as Cursor and Claude
Code) to accelerate your development with Modular.
* [Build an MLP block as a graph module tutorial](/max/develop/build-an-mlp-block):
Learn how to create reusable `Module` components in your MAX graphs.
* [Write custom ops for PyTorch
tutorial](/max/develop/custom-kernels-pytorch) (Beta feature): Learn to write
high-performance GPU kernels for your PyTorch models with Mojo.
* [Profile MAX kernel
performance](https://github.com/modular/modular/blob/main/max/docs/kernel-profiling.md):
Learn how to set up Nsight Compute to profile your Mojo-based kernels on NVIDIA
GPUs.
Major updates:
* [Build custom ops for GPUs tutorial](/max/develop/build-custom-ops):
Now includes how to write hardware-specific functions for CPUs and GPUs.
* [Optimize a matrix multiply custom op
tutorial](/max/develop/custom-ops-matmul): Migrated from a Recipe with
revisions to help you improve the performance of your GPU custom ops.
### MAX models {#25-4-models}
* Added the OLMo 2 model architecture
([`olmo2`](https://github.com/modular/modular/tree/modular/v25.4.0/max/pipelines/architectures/olmo2)).
[Try OLMo 2 now](https://builds.modular.com/models/OLMo-2-1124/7B).
* Added Google's Gemma 3 multimodal model architecture
([`gemma3multimodal`](https://github.com/modular/modular/tree/modular/v25.4.0/max/pipelines/architectures/gemma3)).
[Try Gemma3 now](https://builds.modular.com/models/gemma-3-it/1B).
* Added the Qwen 3 model architecture
([`qwen3`](https://github.com/modular/modular/tree/modular/v25.4.0/max/pipelines/architectures/qwen3)).
[Try Qwen3 now](https://builds.modular.com/models/Qwen3/1.7B).
* Added the InternVL3 model architecture
([`internvl`](https://github.com/modular/modular/tree/modular/v25.4.0/max/pipelines/architectures/internvl)).
This is still a work in progress.
* GGUF-quantized Llamas (`q4_0`, `q4_k`, and `q6_k`) are now supported with the
paged KVCache strategy.
### MAX framework {#25-4-max}
#### Inference server {#25-4-max-serve}
* Inflight batching no longer requires chunked prefill.
* Expanded token sampling logic, including `top_k`, `min_p`, `min_new_tokens`,
and `temperature`.
* Extended sampling configuration to be per-request, so different requests can
use different sampling hyperparameters.
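These parameters compose in the standard way: temperature rescales the logits, then top-k restricts sampling to the highest-scoring candidates. A NumPy sketch of that pipeline (illustrative only, not MAX's sampling kernel):

```python
import numpy as np

def sample(logits, top_k, temperature, rng):
    # Temperature rescales the distribution; lower values sharpen it.
    scaled = logits / temperature
    # Top-k: keep only the k highest logits, mask the rest to -inf.
    kth = np.sort(scaled)[-top_k]
    scaled = np.where(scaled >= kth, scaled, -np.inf)
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()
    return rng.choice(len(logits), p=probs)

rng = np.random.default_rng(0)
token = sample(np.array([0.1, 3.0, 2.0, -1.0]), top_k=2, temperature=0.7, rng=rng)
# token is always 1 or 2: the two highest logits
```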
* Removed support for TorchScript and torch MLIR models.
#### `max` CLI {#25-4-max-cli}
* Added the `--use-subgraphs` flag to `max generate` to allow for the use of
subgraphs in the model.
* Added the `--port` option to specify the port number with the `max serve`
command.
#### Python API {#25-4-max-python}
* Lots of new APIs in the [`max.nn`](/max/api/python/nn/) package.
* Added `max.mojo.importer` module to import Mojo code into Python. See the
docs for [calling Mojo from Python](/mojo/manual/python/mojo-from-python).
* Added
[`Graph.add_subgraph()`](/max/api/python/graph/Graph#max.graph.Graph.add_subgraph)
to allow for the addition of a subgraph to a graph.
* Added
[`Module.build_subgraph()`](/max/api/python/nn/module#max.nn.module.Module.build_subgraph)
to allow for the creation of a subgraph for a layer that inherits from
`Module`.
* Added the [`call`](/max/api/python/graph/ops#max.graph.ops.call) op
which allows for the execution of a subgraph.
* Added the [`fold`](/max/api/python/graph/ops#max.graph.ops.fold) op for
combining sliding blocks into a larger tensor.
* Added [`KernelLibrary`](/max/api/python/graph/KernelLibrary) as an argument
type for the [`Graph`](/max/api/python/graph/Graph) constructor.
* Added
[`QuantizationConfig`](/max/api/python/graph/quantization#max.graph.quantization.QuantizationConfig)
to specify quantization parameters for ops such as
[`qmatmul()`](/max/api/python/graph/ops#max.graph.ops.qmatmul).
* Added the `strict` argument to the
[`Module.load_state_dict()`](/max/api/python/nn/module#max.nn.module.Module.load_state_dict)
method. When `strict=True` (default), an error is raised if the `state_dict`
contains unused keys. When `strict=False`, extra keys are ignored. This helps
model developers identify missing implementations in their models.
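The strict check amounts to comparing the incoming keys against the module's own. A hypothetical sketch of the behavior described above (the function and argument names are illustrative, not the MAX API):

```python
def check_state_dict(module_keys, state_dict, strict=True):
    """Raise on keys the module doesn't use when strict=True (hypothetical)."""
    unused = set(state_dict) - set(module_keys)
    if strict and unused:
        raise KeyError(f"state_dict contains unused keys: {sorted(unused)}")
    # With strict=False, extra keys are silently ignored.
    return {k: v for k, v in state_dict.items() if k in module_keys}

weights = {"linear.weight": 1, "linear.bias": 2, "extra.weight": 3}
loaded = check_state_dict(["linear.weight", "linear.bias"], weights, strict=False)
# strict=True would raise KeyError on "extra.weight"
```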
* Added audio generator APIs for text-to-speech models (such as
[`AudioGenerator`](/max/api/python/pipelines/core#max.pipelines.core.AudioGenerator),
[`PipelineAudioTokenizer`](/max/api/python/pipelines/core#max.pipelines.core.PipelineAudioTokenizer),
[`TTSContext`](/max/api/python/pipelines/core#max.pipelines.core.TTSContext),
and others). This is still a work in progress.
* The
[`ops.masked_scatter()`](/max/api/python/graph/ops#max.graph.ops.masked_scatter)
function now requires naming the `out_dim` explicitly as it is data-dependent.
For example:
```python
ops.masked_scatter(
inputs_embeds, video_mask, video_embeds, out_dim="unmasked_inputs"
)
```
* Deprecated the `CONTINUOUS` KVCache strategy
([`KVCacheStrategy`](/max/api/python/nn/kv_cache/cache_params/#max.nn.kv_cache.cache_params.KVCacheStrategy)).
Please use `PAGED` KVCache strategy instead.
* Removed the `Settings` argument from
[`LLM`](/max/api/python/entrypoints#max.entrypoints.llm.LLM) constructor. The
server is now automatically configured in the background without consuming an
HTTP port.
* Removed `Graph.unique_symbolic_dim()`.
* Removed `max_to_torch_type()` and `torch_to_max_type()` and replaced them with
[`DType.to_torch()`](/max/api/python/dtype#max.dtype.DType.to_torch) and
[`DType.from_torch()`](/max/api/python/dtype#max.dtype.DType.from_torch),
respectively. This aligns with the corresponding NumPy methods.
* Removed `stats_report` property and `reset_stats_report` method from
[`InferenceSession`](/max/api/python/engine#max.engine.InferenceSession). This
functionality was primarily used for internal PyTorch debugging and is no
longer needed.
* Removed the naive KVCache (`nn.kv_cache.naive_cache`).
* Removed `nn.attention` and `nn.naive_attention_with_rope`.
* Renamed `ops.select` to
[`ops.where`](/max/api/python/graph/ops#max.graph.ops.where). This matches the
name of the similar operation in torch and numpy.
#### Mojo API {#25-4-max-mojo}
* [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor/) now has a
`size` method to get the total number of elements.
* Following our [previous deprecation](#25-3-engine-mojo-api) of the Mojo
`max.driver`, `max.graph` and `max.engine` APIs, we've removed them from the
package and API docs.
As a result, we've also removed Mojo `max.tensor` APIs (including
`Tensor`, `TensorShape`, and `TensorSpec`). You can replace any use with
[`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor/).
#### Custom ops {#25-4-custom-ops}
* Improved error messages when custom op parameters are provided with values that
don't have the proper type.
* The [`ops.custom()`](/max/api/python/graph/ops#max.graph.ops.custom) function
now requires a `device` argument to specify where the operation should execute.
This avoids the need for custom ops to infer their execution device, which can
be error-prone.
* Added the [`max.torch`](/max/api/python/torch) module with the
`CustomOpLibrary` class for using custom Mojo kernels from PyTorch. For
example, with a custom `grayscale` operation written in Mojo:
```mojo
@register("grayscale")
struct Grayscale:
@staticmethod
fn execute[
# The kind of device this is running on: "cpu" or "gpu"
target: StaticString,
](
img_out: OutputTensor[dtype = DType.uint8, rank=2],
img_in: InputTensor[dtype = DType.uint8, rank=3],
ctx: DeviceContextPtr,
) raises:
...
```
You can load it with PyTorch like so:
```python
from max.torch import CustomOpLibrary
op_library = CustomOpLibrary("path/to/custom.mojopkg")
@torch.compile(backend=backend)
def grayscale(pic):
result = pic.new_empty(pic.shape[:-1])
op_library.grayscale(result, pic)
return result
img = (torch.rand(64, 64, 3) * 255).to(torch.uint8)
result = grayscale(img)
```
See our [tutorial to write custom ops for
PyTorch](/max/develop/custom-kernels-pytorch), and our [PyTorch custom
operation
examples](https://github.com/modular/modular/tree/main/max/examples/pytorch_custom_ops),
which range from a very basic "hello world" to the replacement of a layer in
a full model.
#### GPU programming {#25-4-gpu-programming}
* Full support for AMD CDNA3 datacenter GPUs is now available! Specifically,
MI300X and MI325X.
* Added initial support for programming on AMD RDNA3 consumer GPUs. Basic
tuning parameters have been specified for AMD Radeon 780m integrated GPUs. (AMD
RDNA3 support is for GPU programming only; AI models are still missing some GPU
kernels for this architecture.) For details, see the [GPU
requirements](/max/packages#gpu-compatibility).
* Now accepting CPU and GPU kernel contributions. See the [MAX AI kernels
contributing
guide](https://github.com/modular/modular/blob/main/max/kernels/CONTRIBUTING.md).
### Mojo language {#25-4-mojo}
For all the updates to the Mojo language, standard library, and tools, see the
[Mojo changelog](/mojo/changelog).
## v25.3 (2025-05-06)
* [Highlights](#25-3-highlights)
* [Documentation](#25-3-docs)
* [`max` CLI](#25-3-max-cli)
* [MAX models](#25-3-models)
* [MAX Serve](#25-3-serve)
* [MAX Engine & Graph](#25-3-engine)
* [Python API](#25-3-engine-mojo-api)
* [Mojo API](#25-3-engine-mojo-api)
* [Custom ops](#25-3-custom-ops)
* [Kernels](#25-3-kernels)
* [GPU programming](#25-3-gpu-programming)
* [Mojo language](#25-3-mojo)
### ✨ Highlights {#25-3-highlights}
* You can now **install Modular APIs and tools with pip**:
```sh
pip install modular \
--index-url https://download.pytorch.org/whl/cpu
```
This installs the `max` CLI, `max` Python library, `mojo` CLI, and Mojo
libraries. However, the Mojo LSP and debugger are currently not included.
We use the `--index-url` argument to ensure that `torch` installs its CPU
dependencies only, thus avoiding a lot of unnecessary GPU packages. This is a
temporary workaround until we can remove our dependency on `torch`.
* We **open-sourced the MAX AI kernels** and the rest of the **Mojo standard
library**!
The [MAX AI kernels library](/mojo/lib#max-ai-kernels-library) is a new Mojo
API for writing high-performance and portable programs across CPU and GPU, but
it's also [the source code for our CPU/GPU
kernels](https://github.com/modular/modular/tree/main/max/kernels/src). You
can now see the Mojo code we use in MAX to power GenAI workloads on CPUs and
GPUs.
Just like the Mojo standard library, these kernels are open source under the
Apache 2.0 License with LLVM exceptions. Plus, the rest of the Mojo standard
library is also [now open source on
GitHub](https://github.com/modular/modular/tree/main/mojo/std/src).
* **Learn to program GPUs** with [Mojo GPU Puzzles](https://builds.modular.com/puzzles)!
This is a brand new site that offers a hands-on guide to mastering GPU
programming with Mojo. Starting from basic concepts, you'll learn
step-by-step how to program for GPUs by solving increasingly challenging
puzzles.
### Documentation {#25-3-docs}
We've restructured the documentation to unify MAX and Mojo documentation
under the Modular Platform. We believe this improves content discovery with a
simplified navigation and helps unify the platform story as a whole.
We've also added the following new docs:
* [REST API reference](/max/api/serve): Although it's not a new API (our
serving library has supported OpenAI APIs for the last few versions), this
now shows precisely which endpoints and body parameters we support.
* [Speculative decoding](/max/serve/speculative-decoding): An introduction to
using speculative decoding to reduce latency for LLMs. This feature is still in
development.
* [Offline inference](/max/serve/offline-inference): An introduction to our
Python API for running inference with an LLM locally (without sending requests
to a serving endpoint).
* [Introduction to layouts](/mojo/manual/layout/layouts): A guide to working
with dense multidimensional arrays on CPUs and GPUs, using new Mojo `layout`
types that abstract away complex memory layout patterns.
### `max` CLI {#25-3-max-cli}
* Renamed the `max-pipelines` CLI tool to `max`. We recommend re-installing
it as shown in the [`max` CLI docs](/max/cli/).
* Removed the previously deprecated `--use-gpu`, `--serialized_model_path`,
`--save_to_serialized_model_path`, `--max_cache_batch_size`, and
`--huggingface-repo-id` options.
* Moved `InputContext`, `TextContext`, and `TextAndVisionContext` from
`max.pipelines` to `max.pipelines.context`.
### MAX models {#25-3-models}
* Added `Llama4ForConditionalGeneration` support,
featuring new MoE layers. Currently, it is limited to text inputs.
Run the model by calling:
```sh
max generate --model-path meta-llama/Llama-4-Scout-17B-16E-Instruct --devices 0,1,2,3
```
* Added support for running text generations using the Mistral 3 24B model.
Run the model with:
```sh
max generate --model-path mistralai/Mistral-Small-3.1-24B-Instruct-2503 --devices 0
```
* Fixed empty textual outputs for certain Mistral models
([MAX issue 4193](https://github.com/modular/modular/issues/4193)).
* Added support for loading a custom pipeline architecture by module. Passing
`--custom-architectures=folder/path/to/import:my_module` loads architectures
from that module, which must expose them via an `ARCHITECTURES` variable. Once
loaded, a model can be run using the new architectures. The flag can be
specified multiple times to load more modules.
### MAX Serve {#25-3-serve}
* Moved from a radix trie to a hash-based prefix caching implementation with
lower CPU overhead. This improves performance, particularly in workloads with
high cache reuse rates.
* Added experimental support for offloading KVCache to host memory via the
`--enable-kvcache-swapping-to-host` and `--host-kvcache-swap-space-gb` flags.
This allows for superior KVCache reuse through prefix caching in workloads
where the reusable KVCache amount exceeds GPU VRAM.
* Fixed the `usage.prompt_tokens` field in the OpenAI API Usage Info response.
Previously, this field was always set to null, but now it correctly
contains the number of prompt tokens in the request.
* Switched from the Python multiprocessing `Queue` to ZeroMQ. This reduces
networking-related latency between the frontend server process and the model
worker process.
* Stray model workers on Linux now terminate more reliably when the parent
process is killed.
### MAX Engine & Graph {#25-3-engine}
#### Python API {#25-3-engine-python-api}
* We now raise an error if there's a mismatch between the expected device of a
weight on a graph and the device of the actual tensor data specified in
[`InferenceSession.load()`](/max/api/python/engine#max.engine.InferenceSession.load).
* Removed `output_device` argument from
[`Model.execute()`](/max/api/python/engine#max.engine.Model.execute).
* Removed the `copy_inputs_to_device` argument in
[`Model.execute`](/max/api/python/engine#max.engine.Model.execute) to improve
predictability of the API. Now `execute()` raises a `TypeError` if arguments
are passed whose devices don't match the model.
* Swapped the order of the `dtype` and `shape` fields of
[`driver.Tensor`](/max/api/python/driver#max.driver.Tensor).
Previously, the arguments were ordered as `(shape, dtype)`; they are now
`(dtype, shape)` to be in line with other tensor-like types.
* Replaced some instances of
[`Tensor.zeros`](/max/api/python/driver#max.driver.Tensor.zeros)
with `Tensor.__init__` when the engine did not depend on the tensor being zero
initialized. This elides the unnecessary memset to provide a minor performance
improvement.
* Added a new experimental
[`Tensor.inplace_copy_from()`](/max/api/python/driver#max.driver.Tensor.inplace_copy_from).
This allows users to copy the contents of one `Tensor` into another.
* Made the default behavior of [`Weight`](/max/api/python/graph/Weight) expect
the initial allocation on the host. A transfer to the target device is then
inserted, and that transferred value is returned when weights generate an MLIR
value. This reflects the current conservative ownership model around external
weights.
* Added the [`irfft`](/max/api/python/graph/ops/#max.graph.ops.irfft) op, which
computes the inverse real fast Fourier transform (FFT).
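For reference, NumPy's `irfft` demonstrates the transform this op computes: it inverts `rfft`, recovering a real signal from its half-spectrum.

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])
spectrum = np.fft.rfft(x)                     # real FFT: n//2 + 1 complex bins
recovered = np.fft.irfft(spectrum, n=len(x))  # inverse real FFT
assert np.allclose(recovered, x)              # round-trip recovers the signal
```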
* Added the [`argmax`](/max/api/python/graph/ops#max.graph.ops.argmax) op,
which returns the index of the maximum value in an array or sequence.
* Added the [`GroupNorm`](/max/api/python/nn/norm/group_norm) layer.
* Switched layer names so that `max.nn` layers that are implemented with the
deprecated `Layer` class are marked as "V1", and layers that are implemented
with the new [`max.nn.Module`](/max/api/python/nn/module#max.nn.module.Module)
are the default. That is, `max.nn.LinearV2` is now
[`max.nn.Linear`](/max/api/python/nn/Linear), and the
previous `max.nn.Linear` is now
`max.nn.LinearV1`.
* `DeviceRef` values in types and layers are now generally expected to be
explicit rather than implicit.
#### Mojo API {#25-3-engine-mojo-api}
* Removed some functionality from
[`tensor.Tensor`](/mojo/kernels/extensibility/tensor/tensor/Tensor):
* Serializing `Tensor` to disk (`Tensor.tofile(path)` and `Tensor.save(path)`).
* Reading the serialized data back from disk (`Tensor.load(path)` and
`Tensor.fromfile(path)`).
* The `rand` and `randn` methods. Use the equivalents in the Mojo standard
library if you still need to construct a new `Tensor` with random elements
based on a particular `TensorShape`.
* **Deprecated the Mojo Driver, Graph, and Engine APIs**
These APIs are not currently used internally. Instead, we build graphs using
the Python APIs, and our engineering efforts have been focused on making that
experience as robust and user-friendly as possible. As a result, the Mojo
versions of these APIs have not kept pace with new features and language
improvements. These APIs will be open sourced for the community before being
removed.
#### Custom ops API {#25-3-custom-ops}
* You can now pass Mojo source package paths as
[`Graph`](/max/api/python/graph/Graph) custom extensions. The Mojo code will be
compiled automatically; there's no need to run `mojo package` manually as a prior step.
Previously, only pre-compiled `.mojopkg` paths were accepted, requiring the
Mojo code to be built as a prerequisite step before running a `Graph` with a
custom op.
Given a project structure like:
```text
project
|-- main.py
\-- kernels
|-- __init__.mojo
\-- my_custom_op.mojo
```
You can construct a `Graph` in `main.py` that uses the Mojo custom op kernels
simply by passing the source directory:
```python
g = Graph(
...,
custom_extensions = [Path(__file__).parent / "kernels"]
)
```
A change to your Mojo source code defining a custom op will be reflected
immediately the next time the `Graph` is constructed.
* New [image\_pipeline example](https://github.com/modular/modular/tree/main/max/examples/custom_ops)
that demonstrates sequencing custom ops that modify an image, keeping the data
on the GPU between ops before writing it back to the CPU and to disk.
### Kernels {#25-3-kernels}
* More compute overlap is now enabled for Hopper GPUs. This allows finer-grained
scheduling of kernel operations by analyzing producer-consumer patterns within
a compute kernel. As a result, there is more kernel compute overlap, especially
for compute-heavy kernels with data-dependent execution paths.
### GPU programming {#25-3-gpu-programming}
* The CUDA driver requirement has been reduced to version 12.4 and the NVIDIA
driver requirement to version 550. Supporting these earlier versions allows MAX to be more
easily deployed on AWS and GCP, since these are the default versions used by
those cloud providers.
* Added support for programming NVIDIA Jetson Orin GPUs (`sm_87`).
Also see the [Mojo changelog of GPU changes](/mojo/changelog#gpu-changes).
### Mojo language {#25-3-mojo}
* We recently open-sourced the rest of the Mojo standard library, including the
`algorithm`, `benchmark`, `buffer`, `compile`, `complex`, `gpu`, and `layout`
packages. [See it all in
GitHub](https://github.com/modular/modular/tree/main/mojo/std/src).
* We've also open sourced [all our MAX AI
kernels](https://github.com/modular/modular/tree/main/max/kernels/src). This
new library includes `kv_cache`, `layout`, `linalg`, `nn`, `nvml`, and
`quantization`.
For all the updates to the Mojo language, standard library, and tools, see the
[Mojo changelog](/mojo/changelog).
## v25.2 (2025-03-25)
* [Highlights](#25-2-highlights)
* [MAX Serve](#25-2-serve)
* [MAX models](#25-2-models)
* [`max-pipelines` CLI](#25-2-pipelines-cli)
* [MAX Engine](#25-2-engine)
* [Driver APIs](#25-2-driver)
* [Graph APIs](#25-2-graph)
* [Custom ops](#25-2-custom-ops)
* [Hopper Kernels](#25-2-hopper-kernels)
* [GPU programming](#25-2-gpu-programming)
* [Mojo](#25-2-mojo)
* [Documentation](#25-2-documentation)
### ✨ Highlights {#25-2-highlights}
* **Support for NVIDIA Hopper GPUs**
MAX has been optimized to run on Hopper GPUs. For more information on MAX and
NVIDIA's hardware, see the [MAX
container](/max/container#recommended-cloud-instances) documentation.
* **Multi-GPU support**
MAX uses tensor parallelism to distribute work across multiple GPUs so you can
run LLMs like
[`Llama-3.3-70B-Instruct`](https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct),
even with long context windows.
* **Expanded library of MAX models**
We're rapidly growing our library of base model architectures that MAX can
accelerate with MAX Serve (including `Phi3ForCausalLM`, `OlmoForCausalLM`,
and `GraniteForCausalLM`). We also now support GPTQ for the Llama models.
For more information, check out our [MAX model
repository](https://builds.modular.com/?category=models).
* **Advanced E2E optimizations for long context windows**
In-flight batching, chunked prefill, and copy-on-write optimize execution for
prefix-heavy and long-context-window scenarios.
* **GPU programming with Mojo**
Lots of new APIs are now available to enable both low-level GPU programming and
abstracted programming patterns that simplify the code required to write GPU
kernels for your AI models.
### MAX Serve {#25-2-serve}
* Extended MAX Serve batch scheduling to account for the prefix cache. The
scheduler can now create larger batches when many prompt tokens are already
cached, improving throughput up to 10% in some benchmarks.
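The scheduling idea can be sketched in plain Python. This toy model (the names and the token-budget heuristic are illustrative, not MAX Serve's actual scheduler) shows why cached prompt tokens let more requests fit in a batch:

```python
def count_schedulable(prompt_lens, cached_lens, token_budget):
    """Toy batch scheduler: only uncached prompt tokens consume the
    prefill token budget, so high cache hit rates allow larger batches."""
    used = scheduled = 0
    for total, cached in zip(prompt_lens, cached_lens):
        cost = total - cached  # tokens that still need prefill compute
        if used + cost > token_budget:
            break
        used += cost
        scheduled += 1
    return scheduled

# With most of two prompts already cached, two requests fit where
# only one fully-uncached prompt would.
print(count_schedulable([100, 100, 100], [80, 0, 90], 120))
```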
* Added support for in-flight batching, allowing token generation requests to be
scheduled alongside context encoding requests to reduce inter-token latency. This
behavior can be controlled by CLI argument `--enable-in-flight-batch`.
* Added support for copy-on-write on KV blocks when using PagedAttention with
Prefix Caching. This improves the prefix cache hit rate and prefill performance
in some scenarios.
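The copy-on-write idea can be sketched in a few lines of Python (a toy page table, not MAX's actual KV cache implementation): forked sequences share pages read-only, and a page is copied only when one of them writes to it:

```python
class PagedCache:
    """Toy copy-on-write KV pages: forked sequences share a page until
    one of them writes, at which point that sequence gets a private copy."""

    def __init__(self):
        self.pages = {}    # page_id -> data
        self.refs = {}     # page_id -> reference count
        self.next_id = 0

    def alloc(self, data):
        pid = self.next_id
        self.next_id += 1
        self.pages[pid] = list(data)
        self.refs[pid] = 1
        return pid

    def share(self, pid):
        self.refs[pid] += 1  # a forked sequence reuses the same page
        return pid

    def write(self, pid, index, value):
        if self.refs[pid] > 1:  # copy on first write to a shared page
            self.refs[pid] -= 1
            pid = self.alloc(self.pages[pid])
        self.pages[pid][index] = value
        return pid
```

Sharing raises the cache hit rate because a fork costs nothing until the sequences actually diverge.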
* MAX Serve now supports `transformers` v4.49.0, with a patch
to avoid graph breaks when using `torch.compile()` on Llama models.
* Added support for recording HTTP traffic out to a file for diagnostics or later
replay.
### MAX models {#25-2-models}
* Added support for executing `LlamaForCausalLM` architecture models on multiple
GPUs. The model uses tensor parallelism automatically when passing multiple
device IDs to the `--devices` CLI argument. Try running
[`meta-llama/Llama-3.3-70B-Instruct`](https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct)
on 4 GPUs with the following example:
```sh
max-pipelines generate --model-path=meta-llama/Llama-3.3-70B-Instruct \
--quantization-encoding bfloat16 \
--devices gpu:0,1,2,3 \
--prompt="Design a
self-sustaining colony on Neptune's moon Triton with a myth/science
fusion name, three quantum tech breakthroughs, one ethical debate, a
neon-lit cultural ritual, and a hidden flaw—presented in bullet points."
```
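Tensor parallelism splits each large weight matrix across devices, with each GPU computing a slice of the output that is then combined. A toy row-parallel matrix-vector product illustrates the idea (this is a sketch of the concept, not MAX's implementation):

```python
def row_parallel_matvec(weight_shards, x):
    """Each 'device' holds a contiguous block of the weight rows and
    computes its slice of the output; concatenating the slices recovers
    the full result, as in tensor-parallel linear layers."""
    result = []
    for shard in weight_shards:  # one shard per device
        result.extend(sum(w * xi for w, xi in zip(row, x)) for row in shard)
    return result

# Full weight [[1, 0], [0, 1], [1, 1]] split across two "devices":
print(row_parallel_matvec([[[1, 0]], [[0, 1], [1, 1]]], [2, 3]))  # [2, 3, 5]
```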
* Added support for the `Phi3ForCausalLM` model architecture (such as
[`microsoft/phi-4`](https://huggingface.co/microsoft/phi-4)). For example:
```sh
max-pipelines generate \
--model-path microsoft/phi-4 \
--prompt "Write bubble sort in mojo"
```
* Added support for the `OlmoForCausalLM` model architecture (such as
[`allenai/OLMo-1B-0724-hf`](https://huggingface.co/allenai/OLMo-1B-0724-hf)). For
example:
```sh
max-pipelines generate \
--model-path allenai/OLMo-1B-0724-hf \
--prompt "Write bubble sort in mojo"
```
* Added support for the `GraniteForCausalLM` model architecture (such as
[`ibm-granite/granite-3.1-8b-instruct`](https://huggingface.co/ibm-granite/granite-3.1-8b-instruct)).
For example:
```sh
max-pipelines generate \
--model-path ibm-granite/granite-3.1-8b-instruct \
--prompt "Write bubble sort in mojo"
```
* Added support for:
* [`microsoft/Phi-3.5-mini-instruct`](https://huggingface.co/microsoft/Phi-3.5-mini-instruct)
* [`microsoft/phi-4`](https://huggingface.co/microsoft/phi-4)
* [`LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`](https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct)
* [`LGAI-EXAONE/EXAONE-3.5-2.4B-Instruct`](https://huggingface.co/LGAI-EXAONE/EXAONE-3.5-2.4B-Instruct)
* We now support GPTQ quantization for models that run on the GPU. This is
handled transparently when the model weights are specified. For example, this
runs Llama 3.1 8B using int4-quantized GPTQ weights:
```sh
max-pipelines generate \
--model-path hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4 \
--prompt "Why is the sky blue?" \
--max-batch-size 1 \
--max-length 10000
```
This reduces the total memory consumption of this model from \~16 GB to \~5 GB,
allowing the model to fit in the RAM of smaller GPUs.
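The savings follow directly from per-weight storage: bfloat16 uses 16 bits per parameter while int4 GPTQ uses 4, plus a small overhead for scales and zero points (which is why the real footprint lands nearer ~5 GB than ~4 GB). A rough back-of-the-envelope, assuming roughly 8B parameters:

```python
params = 8_000_000_000        # approximate parameter count for Llama 3.1 8B
bf16_gb = params * 2 / 1e9    # 2 bytes per bfloat16 weight
int4_gb = params * 0.5 / 1e9  # 4 bits per int4 weight (scales/zeros excluded)
print(f"bf16: ~{bf16_gb:.0f} GB, int4: ~{int4_gb:.0f} GB")
```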
* Model weights are now downloaded in parallel.
* Added constraints on whitespace during [Structured
Output](/max/serve/structured-output). This reduces token counts and improves
model adherence.
* Added jump-ahead decoding during Structured Output. This auto-completes tokens
when a single valid path forward is identified, improving completion times by
up to \~20% for long prompts.
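Conceptually, jump-ahead decoding appends tokens without calling the model whenever the output grammar permits exactly one continuation. A toy sketch (the grammar callback here is hypothetical, not the actual API):

```python
def jump_ahead(tokens, allowed_next):
    """Extend the sequence for free while the grammar allows exactly
    one next token; stop when the model must actually choose."""
    tokens = list(tokens)
    while True:
        candidates = allowed_next(tokens)
        if len(candidates) != 1:
            return tokens
        tokens.append(next(iter(candidates)))

# A schema that forces '"name"' then ':' after '{' lets those tokens be
# auto-completed; decoding resumes where multiple values are valid.
grammar = {"{": {'"name"'}, '"name"': {":"}, ":": {'"a"', '"b"'}}
print(jump_ahead(["{"], lambda seq: grammar.get(seq[-1], set())))
```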
* In the event of an unhandled exception, we now use the standard Python
traceback format instead of using pretty-printed Rich tracebacks.
* You must now explicitly import `LLM` from
[`max.entrypoints.llm`](/max/api/python/entrypoints) rather than from
`max.entrypoints` as before.
* The `max.pipelines.dataprocessing.tokenizer` and
`max.pipelines.dataprocessing.gguf_utils` modules have been removed.
* The previously deprecated `PipelineConfig.architecture` field and its
corresponding `--architecture` CLI argument have been removed.
### `max-pipelines` CLI {#25-2-pipelines-cli}
* The `--devices` CLI argument now supports a comma-separated list of GPU IDs
prefixed with `gpu:` like `--devices=gpu:0,1,2,3`. We no longer support the
previous `--devices=gpu-` format.
```sh
max-pipelines generate --model-path=meta-llama/Llama-3.3-70B-Instruct \
--quantization-encoding bfloat16 \
--devices gpu:0,1,2,3 \
--prompt="Design a self-sustaining colony on Neptune's moon Triton with a myth/science fusion name, three quantum tech breakthroughs, one ethical debate, a neon-lit cultural ritual, and a hidden flaw—presented in bullet points."
```
* Removed `--huggingface-repo-id`
[PipelineConfig](/max/api/python/pipelines/config/#max.pipelines.config.PipelineConfig)
option and CLI argument in favor of `--model-path`.
* We consolidated `--model-path` and `--weight-path`. Valid `--weight-path` values
now override `--model-path`, which handles both local and remote (Hugging Face)
cases. If we cannot derive the weights from the `--weight-path`, we now fall back
to the `--model-path`, which you must set explicitly.
* Added `--huggingface-revision` option, to allow selecting a non-default branch
or a specific commit in a Hugging Face model repository.
### MAX Engine {#25-2-engine}
* The MAX graph compiler now has kernel caching. This is a significant
improvement to our compilation pipeline. Here are some of the highlights:
* Up to 28% faster compilation times when making iterative changes to models
* Improved caching between different but similar models (up to 27% faster)
* Lays foundation for future caching optimizations
What does this mean for you? Faster development cycles! When you're working on
model pipelines and making changes to the graph, the graph compiler will now
intelligently reuse kernels that haven't changed, significantly reducing
compilation times.
The improvements are particularly noticeable during iterative development, with
compilation times dropping from \~80s to \~57s in some cases of compiling
Llama3.1-8B for 4 GPUs. Even when compiling different models from the same family
(like Llama/Granite variants), you'll see significant speedups on subsequent
compilations.
### Driver APIs {#25-2-driver}
* Added `Accelerator.can_access(other: Device) -> bool` method to check if one
device can directly access memory of another device.
* Fixed a bug in `max.driver.tensor.load_max_tensor()` for `bfloat16` dtype,
which would cause an error about mmap size being too large.
* `max.driver.Tensor.item()` now works on any single-element tensor (previously
restricted to rank-0 tensors).
* Added
[`Device.synchronize()`](/max/api/python/driver#max.driver.Device.synchronize),
which ensures all operations on the device complete before returning.
* Removed `MojoCallContextPtr` in favor of `DeviceContextPtr`.
`MojoCallContextPtr` only contained a `DeviceContextPtr`, so this change
directly exposes the `DeviceContextPtr`. Custom ops using `MojoCallContextPtr`
now directly take a `DeviceContextPtr` argument:
```mojo
@staticmethod
fn execute[
type: DType, rank: Int
](
output: OutputTensor[type=type, rank=rank],
input: InputTensor[type=type, rank=rank],
ctx: MojoCallContextPtr,
):
```
becomes
```mojo
@staticmethod
fn execute[
type: DType, rank: Int
](
output: OutputTensor[type=type, rank=rank],
input: InputTensor[type=type, rank=rank],
ctx: DeviceContextPtr,
):
```
* You can now skip compiling a GPU kernel before enqueueing it, and instead pass
a function directly to `ctx.enqueue_function[func](...)`:
```mojo
fn func():
print("Hello from GPU")
@register("custom_op")
struct CustomOp:
@staticmethod
fn execute(ctx: DeviceContextPtr) raises:
var dev_ctx = ctx.get_device_context()
dev_ctx.enqueue_function[func](grid_dim=1, block_dim=1)
```
However, this convenience incurs an overhead of around 50-500 nanoseconds per
enqueue. If you're launching the same function with the same parameters multiple
times, you can still compile the function once and pass it to `ctx.enqueue_function`:
```mojo
var compiled_func = ctx.compile_function[func]()
# Multiple kernel launches with the same function/parameters
ctx.enqueue_function(compiled_func, grid_dim=1, block_dim=1)
ctx.enqueue_function(compiled_func, grid_dim=1, block_dim=1)
```
* Changed `Accelerator` and `CPU` from factory methods that created `Device`
objects in Python (which were accelerators and CPUs in the C++ implementation) to
actual Python types. This change elevates the `Accelerator` and `CPU` type
concepts to Python, making them types rather than methods.
This allows type annotations in Python. For example, a list of accelerators
used to be defined like this:
```python
graph_devices: list[DeviceRef]
```
Now it can be defined like this:
```python
graph_devices: list[Accelerator]
```
* Elementwise operations (e.g. `__add__`) have been removed from `Tensor`
(that is, `tensor_internal.Tensor`). This `Tensor` type is being phased out;
please migrate usage to `LayoutTensor`.
### Graph APIs {#25-2-graph}
* The `nn` package is now [`max.nn`](/max/api/python/nn/).
* Added [`ops.chunk`](/max/api/python/graph#max.graphs.ops.chunk) to support
chunking tensors along an axis.
* Added support for while loops with [`ops.while_loop`](/max/api/python/graph#max.graphs.ops.while_loop).
* Added support for conditional execution with [`ops.cond`](/max/api/python/graph#max.graph.ops.cond).
* Added axis reduction overloads for
[`ops.min`](/max/api/python/graph/ops#max.graph.ops.min) and
[`ops.max`](/max/api/python/graph/ops#max.graph.ops.max). For example:
`ops.min(tensor, axis=-1)`.
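In plain-Python terms, the axis overload reduces along the given dimension rather than over the whole tensor. For a rank-2 input with `axis=-1`:

```python
def min_last_axis(matrix):
    # Equivalent of a min reduction along axis=-1 for a rank-2 input:
    # one minimum per row instead of a single global minimum.
    return [min(row) for row in matrix]

print(min_last_axis([[3, 1], [5, 4]]))  # [1, 4]
```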
* The [`gelu()`](/max/api/python/graph/ops#max.graph.ops.gelu) function now accepts
an `approximate` keyword. The keyword controls the `gelu` approximation with
`none`, `tanh`, and `fast` approximations accepted.
* Removed the `roundeven()` operation from the Python API. The
[`round()`](/max/api/python/graph/ops#max.graph.ops.round) operation now has the
same behavior as `roundeven()`, so there is no need for both to exist.
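Round-half-to-even ("banker's rounding") is also what Python's built-in `round()` does, so the behavior is easy to check locally:

```python
# Halfway cases round to the nearest even integer, so 0.5 rounds down
# to 0 and 1.5 rounds up to 2 rather than always rounding up.
print([round(x) for x in (0.5, 1.5, 2.5, 3.5)])  # [0, 2, 2, 4]
```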
* Added helpers to create analogous tensors from buffer types and vice versa.
* Added `max.nn.Module`, a base class for writing layers and constructing
networks of layers (e.g. using `max.nn.Sequential`). Currently, this class
supports graph building by ensuring that all weight names are unique and
systematically generated. This class also supports managing the weight values
with the `module.state_dict()` and `module.load_state_dict()` methods. More
functionality and documentation will be added in future releases.
### Custom ops {#25-2-custom-ops}
* Changes have been made to the way that custom ops are registered: rather
than using the `num_dps_outputs` attribute on `@compiler.register` to specify the
number of outputs, that number is now inferred from the signature of the custom
operation. Inputs to the operation now use the `InputTensor` type and outputs
from the operation use `OutputTensor`, instead of the previous
`ManagedTensorSlice` for both. This eliminates the need for a manual
`num_dps_outputs` attribute, and makes it safer to work with these inputs and
outputs by preventing accidental writes to input tensors. The new interface looks
something like the following:
```mojo
@compiler.register("add_one_custom")
struct AddOneCustom:
@staticmethod
fn execute[
target: StringLiteral,
](
out: OutputTensor,
x: InputTensor[type = out.type, rank = out.rank],
ctx: DeviceContextPtr,
) raises:
@parameter
@always_inline
fn elementwise_add_one[
width: Int
](idx: IndexList[x.rank]) -> SIMD[x.type, width]:
return x.load[width](idx) + 1
foreach[elementwise_add_one, target=target](out, ctx)
```
* The `foreach` function now `raises` to be able to handle errors within an
elementwise calculation.
### Hopper kernels {#25-2-hopper-kernels}
State-of-the-Art Kernels in Mojo for H100/H200 GPUs
* **Hopper Architecture Matrix Multiplication Kernels**: The implementation
achieved performance comparable to NVIDIA's highly optimized cuBLAS library.
These kernels take full advantage of the Tensor Cores in Hopper architecture GPUs
to accelerate the fundamental matrix multiplication operations that underpin deep
learning workloads.
* **Multi-GPU AllReduce Implementation**: The AllReduce operation is critical for
distributed inference across multiple GPUs, as it efficiently aggregates
partial results across devices. The Mojo implementation surpassed NVIDIA's NCCL library in performance
benchmarks. This improvement reduces communication overhead during distributed
inference.
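Functionally, AllReduce leaves every device holding the same elementwise reduction of all devices' buffers. A minimal sketch of the semantics (not the optimized communication pattern real kernels use):

```python
def allreduce_sum(device_buffers):
    """Every 'device' ends up with the elementwise sum across all
    devices; real implementations use ring or tree exchanges to
    minimize communication."""
    total = [sum(vals) for vals in zip(*device_buffers)]
    return [list(total) for _ in device_buffers]

print(allreduce_sum([[1, 2], [3, 4]]))  # [[4, 6], [4, 6]]
```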
* **MAX Attention Kernel with Flash Attention 3:** This implementation
incorporates the latest Flash Attention 3 algorithm and extends it, which
significantly accelerates the computation of attention mechanisms in transformer
models. The MAX attention kernel optimizes memory access patterns and
computational steps, reducing both the memory footprint and execution time of
attention operations. This is particularly important for LLMs where attention
calculations represent a substantial portion of the computational workload.
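For reference, the quantity these kernels compute is ordinary scaled dot-product attention; flash-attention-style kernels produce the same result tile-by-tile without materializing the full score matrix. A single-query sketch of the math:

```python
import math

def attention(q, k, v):
    """Single-query scaled dot-product attention:
    softmax(q @ K^T / sqrt(d)) @ V."""
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, key)) / math.sqrt(d) for key in k]
    m = max(scores)                          # subtract max for stability
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    weights = [e / z for e in exps]
    return [sum(w * row[j] for w, row in zip(weights, v)) for j in range(len(v[0]))]

# A zero query weights both keys equally, averaging the values.
print(attention([0.0, 0.0], [[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]]))
```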
### GPU programming {#25-2-gpu-programming}
* Added the Mojo `max.driver` API to enable dispatching
GPU functions from Mojo.
Check out [examples for GPU programming in
Mojo](https://github.com/modular/modular/tree/main/mojo/examples/gpu-functions),
which use this new API.
### Mojo {#25-2-mojo}
Mojo is a crucial component of the MAX stack that enables all of MAX's
performance-oriented code across hardware. For all the updates to the Mojo
language, standard library, and tools, see the [Mojo
changelog](/mojo/changelog).
### Documentation {#25-2-documentation}
New examples for writing custom ops:
* [`fused_attention`](https://github.com/modular/modular/blob/main/examples/custom_ops/kernels/fused_attention.mojo)
demonstrates complex GPU programming using MAX abstractions for a
practical use in AI model development.
* [`matrix_multiplication`](https://github.com/modular/modular/blob/main/examples/custom_ops/kernels/matrix_multiplication.mojo)
includes a series of progressive optimizations for matrix multiplications
on GPUs.
* [`histogram`](https://github.com/modular/modular/blob/main/examples/custom_ops/kernels/histogram.mojo)
shows how to implement the histogram pattern as a custom op.
* New [examples for GPU programming in
Mojo](https://github.com/modular/modular/tree/main/mojo/examples/gpu-functions)
using the new MAX Driver API.
These use a Mojo programming model that should look familiar to CUDA C
programmers, showing how to define and dispatch GPU functions within a
single Mojo file. These examples recreate the first three samples from
the popular textbook ["Programming Massively Parallel
Processors"](https://www.amazon.com/Programming-Massively-Parallel-Processors-Hands/dp/0323912311),
showing how basic concepts translate from CUDA into Mojo. There's also a
Mandelbrot set calculation example that parallels a similar one in the
existing custom ops examples.
* New [MAX containers](/max/container/) available. For
more information on the base and full MAX containers, see [Container
contents](/max/container/#container-contents).
## v25.1.1 (2025-02-19)
Fixed performance issues in autoregressive models with paged attention
by setting sensible, platform-specific default values for `--max-num-steps`.
## v25.1 (2025-02-13)
* [Highlights](#25-1-highlights)
* [Documentation](#25-1-docs)
* [MAX Serve](#25-1-serve)
* [MAX models](#25-1-max-models)
* [MAX Engine](#25-1-engine)
* [Graph APIs](#25-1-graph)
* [Pipeline APIs](#25-1-pipelines)
* [GPU programming](#25-1-gpus)
* [Mojo](#25-1-mojo)
### ✨ Highlights {#25-1-highlights}
* **Custom ops for GPUs**
Our new custom op API allows you to extend MAX Engine with new graph
operations written in Mojo that execute on either CPU or GPU, providing full
composability and extensibility for your models. See more in the section
about [GPU programming](#25-1-gpus).
* **Enhanced support for agentic workflows**
MAX Serve now supports function calling, which allows you to instruct your
model to interact with other systems, such as retrieve data and execute
external tasks. [Learn more about function calling and tool
use](/max/serve/function-calling).
MAX Serve now supports structured output (also known as constrained decoding)
for MAX models on GPU. This allows you to enforce the output format from a
model using an input schema that defines the output structure. [Learn more about
structured output](/max/serve/structured-output).
* **Extended model architecture support**
* MAX Serve now supports multimodal models that take both text and image
inputs. For example, see [how to deploy Llama 3.2
Vision](/max/tutorials/deploy-llama-vision).
* MAX Serve now supports text embedding models. Learn how to [deploy a text
embedding model](/max/tutorials/run-embeddings-with-max-serve).
* **New `max-pipelines` CLI tool**
Instead of cloning our GitHub repo to access our latest GenAI models, you can
instead install the `max-pipelines` CLI tool and quickly run an inference or
deploy an endpoint.
### Documentation {#25-1-docs}
New tutorials:
* [Build custom ops for GPUs](/max/develop/build-custom-ops)
* [Serverless GPU inference on Google Cloud
Run](/max/tutorials/deploy-serverless-cloud-run)
* [Generate image descriptions with Llama 3.2
Vision](/max/tutorials/deploy-llama-vision)
* [Deploy a text embedding model](/max/tutorials/run-embeddings-with-max-serve)
Other docs:
* [Function calling and tool use](/max/serve/function-calling)
* [Structured output](/max/serve/structured-output)
* [Prefix caching with PagedAttention](/max/serve/prefix-caching)
* `max-pipelines` CLI
### MAX Serve {#25-1-serve}
* The `/v1/completions` REST endpoint now supports:
* Pre-tokenized prompts.
* Image inputs for multimodal models such as `Llama-3.2-11B-Vision-Instruct`.
For an example, see [how to generate image
descriptions with Llama 3.2 Vision](/max/tutorials/deploy-llama-vision).
**Known issue:** You might receive faulty results because some parts of the
text prompt get ignored for certain input combinations. We've identified
the problem and will have a fix in a subsequent nightly
release.
* Function calling and tool use, which allows you to instruct your
model to interact with other systems, such as retrieve data and execute
external tasks. [Learn more about function calling and tool
use](/max/serve/function-calling).
* Structured output (also known as constrained decoding), which allows you to
enforce the output format from a model using a JSON schema and the
`response_format` field. To enable constrained decoding pass
`--enable-structured-output` when running the server. However, this feature
currently works for MAX models on GPU only (support for PyTorch models and
CPU is in progress). [Learn more about structured
output](/max/serve/structured-output).
* Added support for the `/v1/embeddings` API endpoint, allowing you to generate
vector representations using embedding models. See how to [deploy a text
embedding model](/max/tutorials/run-embeddings-with-max-serve).
* MAX Serve can now evict requests when the number of available pages in the
PagedAttention KV cache is limited. Previously, the KV manager would throw an
OOM error when a batch that could not fit in the cache was scheduled.
### MAX models {#25-1-max-models}
* Added the `max-pipelines` CLI tool that simplifies the
process to run inference with GenAI models (specified with a Hugging Face repo
ID) and deploy them to a local endpoint with MAX Serve.
Previously, running or serving these models required cloning the
[modular/max](https://github.com/modular/max) GitHub repo and then running
commands such as `magic run llama3`.
Model-specific commands like `llama3` and `replit` have been
removed. They're now standardized and subsumed by flags like
`--model-path` in the `max-pipelines` tool. Arguments such as
`--max-length` and `--weight-path` are also still supported by
`max-pipelines`.
To view a list of supported model architectures from Hugging Face, run
`max-pipelines list`.
* Added support for PagedAttention, which improves memory efficiency by
partitioning the KV cache into smaller blocks, reducing fragmentation and
enabling larger inference batches. You can enable it with
`--cache-strategy=paged` and `--kv-cache-page-size` with a value that's a
multiple of 128.
* Added support for prefix caching in all cases where PagedAttention is
supported. This allows for more efficient usage of KVCache and improved prefill
performance for workloads with common prefixes. You can enable it by setting
`--enable-prefix-caching`. For more information, see [Prefix caching with
PagedAttention](/max/serve/prefix-caching).
* Batch size and max length are now inferred from available memory and the
Hugging Face model's default max length, respectively. If a configuration leads
to an OOM, then we provide recommendations (to the best of our ability) to the
user to fit the model into memory.
* Added support for heterogeneous KV caches for multi-modal models, such as
Llama Vision, which cache different KV states for self and cross attention
layers.
* Added support for embedding models, starting with MPNet. For example:
```shell
max-pipelines generate \
--model-path=sentence-transformers/all-mpnet-base-v2 \
--prompt="Encode this sentence."
```
Also see [how to deploy a text
embedding model](/max/tutorials/run-embeddings-with-max-serve).
* Added support for image and text multimodal models:
* `max-pipelines generate` now accepts image input with `--image_url`.
* Added an experimental Pixtral pipeline you can run as follows:
```shell
max-pipelines generate \
--model-path=mistral-community/pixtral-12b \
--prompt="What is in this image? [IMG]" \
--image_url=http://picsum.photos/1024/1024
```
The pipeline is automatically used for all models implementing the
`LlavaForConditionalGeneration` architecture.
The implementation currently has a limit of one image. We plan to support an
arbitrary number of images of mixed sizes soon.
* Added an experimental Llama Vision pipeline you can run as follows:
```shell
max-pipelines generate \
--model-path=meta-llama/Llama-3.2-11B-Vision-Instruct \
--prompt="<|image|><|begin_of_text|>What is in this image?" \
--image_url=http://picsum.photos/1024/1024
```
The pipeline is automatically used for all models implementing the
`MllamaForConditionalGeneration` architecture.
Note: This model is gated and requires that you set the
[`HF_TOKEN`](https://huggingface.co/docs/huggingface_hub/en/package_reference/environment_variables#hftoken)
environment variable. See
[Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct).
* See [how to generate image
descriptions with Llama 3.2 Vision](/max/tutorials/deploy-llama-vision).
* Added support for the `Qwen2ForCausalLM` model architecture (such as
`Qwen/Qwen2.5-7B-Instruct`). For example:
```shell
max-pipelines generate \
--model-path=Qwen/Qwen2.5-7B-Instruct \
--prompt="Write bubble sort in python" \
--quantization-encoding bfloat16
```
* Added support for offline batched inference for text-based LLMs, allowing you
to load a model and run inference with a batch of inputs directly from Python,
instead of relying on an HTTP interface. For an example, see
[`examples/offline-inference/basic.py`](https://github.com/modular/modular/blob/main/examples/offline-inference/basic.py).
* The `--max-cache-batch-size` flag has been deprecated in favor of
`--max-batch-size`. Using `--max-cache-batch-size` now emits a deprecation
warning and will stop working in a future release.
* The `--use-gpu` flag has been deprecated in favor of `--devices=cpu`,
`--devices=gpu`, or `--devices=gpu-0,gpu-1,...`. If the device isn't specified,
the model runs on the first available GPU, or CPU if no GPUs are available.
### MAX Engine {#25-1-engine}
* Improved internal kernel compilation speed by 1.5x to 4x across different models.
We've revamped our GPU compilation process so that all kernels in a program
are compiled together into a single LLVM module, then split into separate
kernels afterward. This ensures shared code between kernel entry points is
only compiled once. For example, we observe a 3.7x speedup for Llama3.1-8b
GPU startup time.
* Improved initial model execution speed on NVIDIA GPUs.
Instead of compiling to PTX and performing just-in-time compilation during
runtime, we now generate CUBIN binaries directly. While this increases
initial compilation time, it significantly improves execution speed.
* The kernels have been further tuned for performance on NVIDIA A100 GPUs.
#### Graph APIs {#25-1-graph}
* You can now write custom operations (ops) in Mojo, and add them to a graph
constructed in Python, using
[`custom()`](/max/api/python/graph/ops#max.graph.ops.custom) and
[`inplace_custom()`](/max/api/python/max/graph/ops#max.graph.ops.inplace_custom).
For more detail, see the section below about [GPU programming](#25-1-gpus).
* Cached compiled MAX graphs that make use of custom operations now get
invalidated when the implementation of the custom operations change.
* [`Graph.add_weight()`](/max/api/python/graph/Graph#max.graph.Graph.add_weight)
now takes an explicit `device` argument. This enables explicitly passing
GPU-resident weights to
[`session.load()`](/max/api/python/engine#max.engine.InferenceSession.load) via
the weights registry to initialize the model.
* [`max.graph.Weight`](/max/api/python/graph/Weight) now inherits
from `TensorValue`, allowing you to call `weight.cast()` or `weight.T`. As such,
the [`TensorValue`](/max/api/python/graph/TensorValue#max.graph.TensorValue) no
longer accepts `Weight` for the `value` argument.
#### Pipeline APIs {#25-1-pipelines}
* [`TextTokenizer.new_context()`](/max/api/python/pipelines/tokenizer#max.pipelines.tokenizer.TextTokenizer.new_context)
now supports tool definitions passed through its `request` argument (via
`TokenGeneratorRequest.tools`).
It also now supports JSON schemas passed through its `request` argument (via
[`TokenGeneratorRequest.response_format`](/max/api/python/pipelines/interfaces/#max.pipelines.interfaces.TokenGeneratorRequest.response_format)).
* Removed the default `num_steps` value for
[`TokenGenerator.next_token()`](/max/api/python/pipelines/interfaces/#max.pipelines.interfaces.TokenGenerator.next_token),
ensuring users pass a value, reducing the potential for silent errors.
* [`KVCacheStrategy`](/max/api/python/pipelines/kv_cache/cache_params#max.pipelines.kv_cache.cache_params.KVCacheStrategy)
now defaults to `MODEL_DEFAULT`.
Instead of always using the "continuous" caching strategy as before, the KV
caching strategy now defaults on an architecture-specific basis to ensure the
most optimized caching strategy is used.
* The
[`Linear`](/max/api/python/nn/Linear)
layer now has a `create()` class method that automatically creates
specializations of `Linear` for non-quantized, k-quant, or GPTQ layers.
* Added
[`nn.Conv1D`](/max/api/python/nn/conv#max.nn.conv.Conv1D)
for audio models like Whisper.
#### GPU programming {#25-1-gpus}
This release includes all-new APIs for programming GPUs. The way to write code
for GPUs is to create custom operations with GPU functions that you can load
into a MAX graph. This foundational API includes a few key components:
* Mojo APIs to write custom op functions:
* The [`@compiler.register`](/max/api/mojo-decorators/compiler-register)
decorator is applied to a Mojo struct that implements a custom op in an
`execute()` function—for either CPU or GPU—and a `shape()` function that
defines the custom op's output tensor.
* The [`max.tensor`](/mojo/kernels/extensibility/tensor/) package adds
essential Mojo APIs for writing custom ops, such as:
* The [`foreach()`](/mojo/kernels/extensibility/tensor/managed_tensor_slice/foreach)
function, which efficiently executes an element-wise computation in parallel
on either a GPU or CPU.
* The
[`ManagedTensorSlice`](/mojo/kernels/extensibility/tensor/managed_tensor_slice/ManagedTensorSlice)
type defines the input and output tensors for the custom op.
* Python APIs to load custom ops into a model:
* The [`custom()`](/max/api/python/graph/ops#max.graph.ops.custom) and
`inplace_custom()`
functions allow you to add the previously-defined Mojo custom op to a MAX
graph written in Python.
* The [`InferenceSession`](/max/api/python/engine#max.engine.InferenceSession)
constructor accepts the custom op implementation as a [Mojo
package](/mojo/manual/packages#mojo-packages) in the `custom_extensions`
argument.
For more detail, see the [tutorial to build custom ops for
GPUs](/max/develop/build-custom-ops), or check out this [simple example of
a custom
op](https://github.com/modular/modular/blob/main/examples/custom_ops/kernels/add_custom.mojo).
Additionally, we've added a new [`gpu` package](/mojo/std/gpu/) to the Mojo
standard library that provides low-level programming constructs for working
with GPUs. These APIs let you do things that you can't currently do with the
high-level `foreach()` abstraction above. The Mojo `gpu` APIs allow you to
manually manage interaction between the CPU host and GPU device, manage memory
between devices, synchronize threads, and more. For some examples, see
[`vector_addition.mojo`](https://github.com/modular/modular/blob/main/examples/custom_ops/kernels/vector_addition.mojo)
and
[`top_k.mojo`](https://github.com/modular/modular/blob/main/examples/custom_ops/kernels/top_k.mojo).
### Mojo {#25-1-mojo}
Mojo is a crucial component of the MAX stack that enables all of MAX's
performance-oriented code across hardware. For all the updates to the Mojo
language, standard library, and tools, see the [Mojo
changelog](/mojo/changelog).
## v24.6 (2024-12-17)
This is a huge update that offers a first look at our serving library for
MAX on GPUs!
* [Highlights](#24-6-highlights)
* [Documentation](#24-6-docs)
* [MAX Serve](#24-6-serve)
* [MAX models](#24-6-models)
* [MAX Engine](#24-6-engine)
* [Driver APIs](#24-6-driver-api)
* [Graph compiler](#24-6-graph-compiler)
* [Graph APIs](#24-6-graph-api)
* [Custom op registration](#24-6-custom-ops)
* [Numeric kernels](#24-6-kernels)
* [Mojo](#24-6-mojo)
Also check out our [blog post introducing MAX
24.6](https://www.modular.com/blog/introducing-max-24-6-a-gpu-native-generative-ai-platform).
### ✨ Highlights {#24-6-highlights}
* **MAX Engine on GPUs preview**
We're excited to share a preview of MAX Engine on GPUs. We've created a few
tutorials that demonstrate MAX's ability to run GenAI models with our
next-generation MAX graph compiler on NVIDIA GPU architectures (including
A100, A10, L4, and L40 GPUs). You can experience it today by [deploying
Llama 3 on an A100 GPU](/max/tutorials/max-serve-local-to-cloud).
* **MAX Serve preview**
This release also includes an all-new serving interface called MAX
Serve. It's a Python-based serving layer that supports both
native MAX models when you want a high-performance deployment, and
off-the-shelf PyTorch LLMs from Hugging Face when you want to explore and
experiment—all with GPU support. It provides an OpenAI-compatible REST
endpoint for inference requests, and a Prometheus-compatible metrics
endpoint. You can use a `magic` command to start a local server, or use our
ready-to-deploy MAX container to start an endpoint in the cloud. Try it now
[with an LLM from Hugging Face](/max/tutorials/max-serve-local-to-cloud).
* **Upgraded MAX models**
As we continue to build our Python-based MAX Graph API that allows you to
build high-performance GenAI models, we've made a ton of performance
improvements to the existing models and added a few new models to our GitHub
repo. All the Python-based MAX models now support GPUs and broad model
architectures. For example,
[`llama3`](https://github.com/modular/modular/tree/main/max/pipelines/architectures/llama3)
adds compatibility for the LlamaForCausalLM family, which includes over
20,000 model variants and weights on Hugging Face.
### Documentation {#24-6-docs}
New tutorials:
* [Deploy Llama 3 on GPU with MAX
Serve](/max/tutorials/max-serve-local-to-cloud)
* [Deploy Llama 3.1 on GPU-powered Kubernetes
clusters](/max/tutorials/deploy-max-serve-on-kubernetes)
* [Get started with MAX Graph in
Python](/max/tutorials/get-started-with-max-graph-in-python)
Other new docs:
* [MAX container](/max/container)
* [Benchmark MAX
Serve](https://github.com/modular/modular/tree/main/benchmark)
Also, our documentation is now available for **MAX nightly builds**! If you're
building with a nightly
release, you can
switch to see the nightly docs using a toggle to the right of the search bar.
### MAX Serve {#24-6-serve}
This release includes a preview of our Python-based serving library called MAX
Serve. It simplifies the process to deploy your own inference
server with consistent and reliable performance.
MAX Serve currently includes the following features:
* Deploys locally and to the cloud with our [MAX container
image](/max/container), or with the `magic` CLI.
* An OpenAI-compatible server with streaming `/chat/completions` and
`/completions` endpoints for LLM inference requests.
* Prometheus-compatible [metrics endpoint](/max/container#metrics) with LLM
KPIs (TTFT and ITL) for monitoring and evaluating performance.
* Supports most `TextGeneration` Hugging Face Hub models.
* Multiprocess HTTP/model worker architecture that maximizes CPU core
utilization by distributing incoming requests across multiple processes,
ensuring both high throughput and responsiveness.
* Continuous heterogeneous batching to combine multiple incoming requests into
a single inference (no waiting to fill a batch size) and improve total
throughput.
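The continuous-batching idea described above can be sketched in a few lines of plain Python. This is a toy scheduler to illustrate the concept, not the MAX Serve implementation: new requests join the in-flight batch at every step rather than waiting for a fresh batch to fill.

```python
from collections import deque

# Toy sketch of continuous batching: at each decode step, admit pending
# requests into the active batch (up to max_batch) instead of waiting
# for a full batch before starting.
def run(arrival_schedule, max_batch=4, steps_per_req=2):
    pending, active, log = deque(), {}, []
    for new_requests in arrival_schedule:
        pending.extend(new_requests)
        while pending and len(active) < max_batch:
            active[pending.popleft()] = steps_per_req
        log.append(sorted(active))         # who is batched this step
        for req in list(active):           # one decode step per request
            active[req] -= 1
            if active[req] == 0:
                del active[req]
    return log

# "c" arrives while "a" and "b" are mid-flight and is batched immediately.
log = run([["a", "b"], ["c"], []])
```

With a waiting-to-fill scheduler, "c" would idle until "a" and "b" finished; here it rides along on the very next step.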
There's much more still in the works for MAX Serve, but you can try it today
with our tutorials to [Deploy Llama 3 on GPU with MAX
Serve](/max/tutorials/max-serve-local-to-cloud).
**Known issues:**
* While this release is enough to support typical chatbot applications,
it does not yet support the function-calling portion of the
OpenAI API specification needed to enable robust agentic workflows.
* Sampling is still limited and doesn't currently respect temperature or
other sampling-related API request inputs.
* Structured generation is not supported.
* Support for multi-modal models is still nascent.
### MAX models {#24-6-models}
All of our Python-based GenAI [models on
GitHub](https://github.com/modular/modular/tree/main/max/pipelines/architectures)
now support GPUs!
As we add more models, we're also building a robust set of libraries and
infrastructure that make it easier to build and deploy a growing library of
LLMs. Some of this is available in the new
[`max.pipelines`](/max/api/python/pipelines/) package, and some of it lives
alongside the [models on
GitHub](https://github.com/modular/modular/tree/main/max/pipelines/architectures).
Here are just some of the highlights:
* Deep integration with the Hugging Face ecosystem for a quick-to-deploy
experience, such as using HF Model Hub tools to fetch config files, support for
weights in [safetensor](https://github.com/huggingface/safetensors) format,
support for HF tokenizers, and more. (We also support GGUF weight formats.)
* Expanded set of model abstractions for use by different LLM architectures:
* Attention layers (including highly optimized implementations with
configurable masking, like
[`AttentionWithRope`](https://github.com/modular/modular/tree/main/max/nn/attention/attention_with_rope.py)).
The optimized attention layers include variants that accept an attention
mask. More memory-efficient variants that don't take a mask instead take a
"mask functor" argument to the kernel, which implements masking without
materializing a mask by computing a mask value from input coordinates on the
fly.
* Transformers such as [`Transformer` and
`TransformerBlock`](https://github.com/modular/modular/tree/main/max/nn/transformer/transformer.py).
These include an initial implementation of ragged tensors—tensors for which
each dimension can have a different size, avoiding the use of padding tokens
by flattening a batch of sequences of differing lengths.
* Common layers such as
[`RMSNorm`](https://github.com/modular/modular/tree/main/max/nn/norm/rms_norm.py),
[`Embedding`](https://github.com/modular/modular/tree/main/max/nn/embedding.py),
and
[`Sequential`](https://github.com/modular/modular/tree/main/max/nn/sequential.py).
* KV cache management helpers, like
[`ContinuousBatchingKVCacheManager`](/max/api/python/pipelines/kv_cache/continuous_batching_cache#max.pipelines.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheManager).
* Low-level wrappers over optimized kernels like
[`fused_qk_ragged_rope`](https://github.com/modular/modular/tree/main/max/nn/kernels.py).
These are custom fused kernels that update the KV cache in place. Although
they are custom, they reuse the underlying kernel implementation by passing
in lambda functions used to retrieve inputs and write to outputs in place.
* Added generalized interfaces for text generation such as
[`TokenGenerator`](/max/api/python/pipelines/interfaces#max.pipelines.interfaces.TokenGenerator)
and
[`PipelineModel`](/max/api/python/pipelines/pipeline#max.pipelines.pipeline.PipelineModel),
which provide modularity within the models and serving infrastructure. Also
added a plug-in mechanism
([`PipelineRegistry`](/max/api/python/pipelines/registry#max.pipelines.registry.PipelineRegistry))
to more quickly define new models, tokenizers, and other reusable components.
For example, anything that conforms to
[`TokenGenerator`](/max/api/python/pipelines/interfaces#max.pipelines.interfaces.TokenGenerator)
can be served using the LLM infrastructure within MAX Serve. We then used this
interface to create the following:
* An optimized
[`TextGenerationPipeline`](/max/api/python/pipelines/pipeline#max.pipelines.pipeline.TextGenerationPipeline)
that can be combined with any compatible graph and has powerful performance
features like graph-based multi-step scheduling, sampling, KV cache
management, ragged tensor support, and more.
* A generic
[`HFTextGenerationPipeline`](/max/api/python/pipelines/hf_pipeline#max.pipelines.hf_pipeline.HFTextGenerationPipeline)
that can run any Hugging Face model for which we don't yet have an optimized
implementation in eager mode.
* Models now accept weights via a weights registry, which is passed to the
[`session.load()`](/max/api/python/engine#max.engine.InferenceSession.load)
method's `weights_registry` argument. The decoupling of weights and model
architecture allows implementing all of the different fine-tunes for a given
model with the same graph. Furthermore, because the underlying design is
decoupled, we can later expose the ability to compile a model once and swap
weights out on the fly, without re-compiling the model.
* Added generic implementations of common kernels, which allow you to plug in
different batching strategies (ragged or padded), KV cache management
approaches (continuous batching), masking (causal, sliding window, etc.), and
position encoding (RoPE or ALIBI) without having to re-write any kernel code.
(More about this in a future release.)
* Multi-step scheduling to run multiple token-generation steps on GPU before
synchronizing to the CPU.
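Among the abstractions above, the "mask functor" idea can be illustrated with a small NumPy sketch. This is a hedged toy, not the MAX kernel interface: mask values are computed from query/key coordinates on the fly instead of materializing a full mask tensor.

```python
import numpy as np

# Toy mask functor: given (query index, key index), return the additive
# mask value. A causal mask forbids attending to future positions.
def causal_mask_value(q_idx: int, k_idx: int) -> float:
    return 0.0 if k_idx <= q_idx else -np.inf

seq = 4
scores = np.zeros((seq, seq))  # stand-in attention scores
masked = np.array([[scores[q, k] + causal_mask_value(q, k)
                    for k in range(seq)] for q in range(seq)])
# Future positions (k > q) become -inf, so softmax assigns them zero weight,
# and no [seq, seq] mask tensor ever needs to be allocated by the kernel.
```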
**Updated models:**
* Significant performance upgrades for [Llama
3](https://github.com/modular/modular/tree/main/max/pipelines/architectures/llama3),
and expanded compatibility with the `LlamaForCausalLM` models family. For
example, it also supports Llama 3.2 1B and 3B text models.
**New models:**
* [Mistral
NeMo](https://github.com/modular/modular/tree/main/max/pipelines/architectures/mistral)
(and other `MistralForCausalLM` models)
* [Replit Code V1.5
3B](https://github.com/modular/modular/tree/main/max/pipelines/architectures/replit)
**Known issues:**
* The Q4 quantized models currently work on CPU only.
* Using a large setting for `top-k` with the Llama 3.1 model may lead to
segmentation faults for certain workloads when run on NVIDIA GPUs. This should
be resolved in the latest nightly MAX builds.
* The models currently use a smaller default context window than the
`max_seq_len` specified in the Hugging Face configuration files for a given
model. This can be manually adjusted by setting the `--max-length` parameter to
the desired context length when serving a model.
* Some variants of the supported core models (like `LlamaForCausalLM` with
different numbers of heads, head sizes, and so on) might not be fully optimized yet.
We plan to fully generalize our implementations in a future release.
### MAX Engine {#24-6-engine}
MAX Engine includes much of the core infrastructure that enables MAX to
accelerate AI models on any hardware, such as the graph compiler, runtime,
kernels, and the APIs to interact with them all. It all works without external
dependencies such as PyTorch or CUDA.
This release includes a bunch of performance upgrades to our graph compiler and
runtime. We've added support for NVIDIA GPU architectures (including A100, A10,
L4, and L40 GPUs), and built out new infrastructure so we can quickly add
support for other GPU hardware.
**Engine API changes:**
* [`InferenceSession`](/max/api/python/engine#max.engine.InferenceSession)
now accepts a `custom_extensions` constructor argument, same as `load()`, to
specify model extension libraries.
* The [`Model`](/max/api/python/engine#max.engine.Model) object is now callable
to run an inference.
**Breaking changes**:
* `Model.execute()` signature changed to support GPUs.
* The [`execute()`](/max/api/python/engine#max.engine.Model.execute) function
currently doesn't accept keyword arguments. Instead you can pass tensors as a
[`driver.Tensor`](/max/api/python/driver#max.driver.Tensor), `int`, `float`,
`bool`,
[`np.generic`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.generic),
or [`DLPackArray`](/max/api/python/driver#max.driver.DLPackArray)
([DLPack](https://github.com/dmlc/dlpack)). Note that both PyTorch and NumPy
arrays implement the DLPack protocol, which means you can also pass either of
those types to `execute()`.
* [`execute_legacy()`](/max/api/python/engine#max.engine.Model.execute_legacy)
preserves the semantics of `execute()` with support for keyword arguments to
help with migration, but will be removed in a future release.
`execute_legacy()` doesn't support GPUs.
* Calling `execute()` with positional arguments still works the same.
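Because NumPy arrays implement the DLPack protocol, the zero-copy hand-off mentioned above can be sanity-checked without any MAX APIs at all:

```python
import numpy as np

# NumPy arrays expose the DLPack protocol, which is why they can be
# passed directly as model inputs: the consumer imports the buffer
# zero-copy instead of copying it.
a = np.arange(6, dtype=np.float32)
b = np.from_dlpack(a)  # imports via a.__dlpack__(); no data copy
same_buffer = np.shares_memory(a, b)  # True: both views share one buffer
```

PyTorch tensors work the same way through `torch.from_dlpack()` and `torch.Tensor.__dlpack__()`.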
#### Driver APIs {#24-6-driver-api}
MAX Driver (the [`max.driver`](/max/api/python/driver) module) is a new
component of MAX Engine that's still a work in progress. It provides primitives
for working with heterogeneous hardware systems (GPUs and CPUs), such as to
allocate on-device memory, transfer data between host and device, query device
stats, and more. It's a foundation on which other components of MAX Engine
operate (for example, `InferenceEngine` now uses
[`driver.Tensor`](/max/api/python/driver#max.driver.Tensor) to handle model
inputs and outputs).
**Driver API changes:**
* Added `CUDA()` device to open an NVIDIA GPU.
* Added support for fp16 and bfloat16 dtypes.
* Expanded functionality for `max.driver.Device`, with new class methods and
properties. We are still working on building this out to support more
accelerator features.
* [`driver.Tensor`](/max/api/python/driver#max.driver.Tensor) (and the
`InferenceSession.load()` argument `weights_registry`) now supports zero-copy
interoperability with NumPy arrays and PyTorch tensors, using
[DLPack](https://github.com/dmlc/dlpack) /
[`DLPackArray`](/max/api/python/driver#max.driver.DLPackArray).
* [`driver.Tensor`](/max/api/python/driver#max.driver.Tensor) has new methods,
such as `from_dlpack()`, `element_size()`, `to()`, `to_numpy()`, `view()`,
`zeros()`, and more.
MAX Driver APIs are still changing rapidly and not yet ready for general use.
We'll publish more documentation in a future release.
**Known issues:**
* MAX Driver is currently limited to managing just one NVIDIA GPU at a time (it
does not yet support multi-GPU). It also does not yet support remote devices.
* DLPack support is not complete. For example, streams are not yet supported.
#### Graph compiler {#24-6-graph-compiler}
When you load a model into MAX Engine, the graph compiler is the component that
inspects and optimizes all graph operations (ops) to deliver the best run time
performance on each device.
This release includes various graph compiler improvements:
* Major extensions to support NVIDIA GPUs (and other devices in the future),
including async copies and caching of JIT'd kernels.
* The runtime now performs scheduling to enable GPU compute overlap with the
CPU.
* New transformations to the Mojo kernels to enable a number of optimizations,
including specialization on tensor dimensions, specialization on target
hardware, specialization on non-tensor dimension input to kernels, automatic
kernel fusion between operators, and more.
* New algebraic simplifications and algorithms for ops such as horizontal
fusion of matrix multiplications.
* New CPU-side primitives for device management that are automatically
transformed and optimized to reduce overhead (MAX does not need to use things
like CUDA Graphs).
* Updated memory planning to preallocate device memory (hoist computation from
inference runtime to initialization time) and reduce per-inference overhead.
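The horizontal fusion of matrix multiplications mentioned above can be sketched with NumPy. This is a hedged illustration of the transformation, not the compiler's implementation: two matmuls sharing a left-hand side become one larger matmul over a concatenated right-hand side.

```python
import numpy as np

# Horizontal matmul fusion: x @ w1 and x @ w2 share the same LHS, so
# they can run as a single matmul against concat(w1, w2), then the
# fused result is split back into the two original outputs.
rng = np.random.default_rng(0)
x = rng.standard_normal((3, 8))
w1 = rng.standard_normal((8, 4))
w2 = rng.standard_normal((8, 5))
fused = x @ np.concatenate([w1, w2], axis=1)  # one larger kernel launch
y1, y2 = fused[:, :4], fused[:, 4:]           # recover both results
```

The payoff is one kernel launch and one read of `x` instead of two.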
#### Graph APIs {#24-6-graph-api}
The graph compiler is also exposed through the MAX Graph APIs (the
[`max.graph`](/max/api/python/graph/) package), which allow you to build
high-performance GenAI models in Python.
**Graph API changes:**
* Python stack traces from model execution failures now include a trace to the
original op-creation, allowing for easier debugging during development.
* The [`max.graph`](/max/api/python/graph/) APIs now include preliminary
support for symbolic algebraic expressions using
[`AlgebraicDim`](/max/api/python/graph/type#max.graph.type.AlgebraicDim),
enabling more powerful support for checked dynamic shapes. This allows
expressions like `-Dim("x") - 4`. Furthermore, algebraic expressions simplify to
a canonical form, so that, for example, `-Dim("x") - 4 == -(Dim("x") + 4)` holds.
* More advanced dtype promotion now allows
[`TensorValue`](/max/api/python/graph/TensorValue) math operators to just work
when used with NumPy arrays and Python primitives.
* [`TensorValue`](/max/api/python/graph/TensorValue) has new methods, such as
`broadcast_to()`, `cast()`, `flatten()`, `permute()`, and more.
* Added [`BufferValue`](/max/api/python/graph/BufferValue), which allows for
device-resident tensors that are read and mutated within the graph.
* [`DType`](/max/api/python/dtype#max.dtype.DType) has new methods/properties,
`align`, `size_in_bytes`, and `is_float()`.
* [`Value`](/max/api/python/graph/Value) constructor accepts more types for
`value`.
* [`TensorValue`](/max/api/python/graph/TensorValue) constructor accepts more
types for `value`.
* [`TensorValue.rebind()`](/max/api/python/graph/TensorValue#max.graph.TensorValue.rebind)
accepts a new `message` argument.
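The canonical-form behavior of algebraic dimensions described above can be sketched with a toy affine representation. The `AffineDim` class below is hypothetical, not MAX's `AlgebraicDim`: it shows how reducing expressions to `coeff * name + offset` makes algebraically equal expressions compare structurally equal.

```python
from dataclasses import dataclass

# Toy canonical form for a symbolic dim: coeff * name + offset.
# Because every expression reduces to this shape, -(x + 4) and -x - 4
# end up as the same value and compare equal.
@dataclass(frozen=True)
class AffineDim:
    name: str
    coeff: int = 1
    offset: int = 0

    def __neg__(self) -> "AffineDim":
        return AffineDim(self.name, -self.coeff, -self.offset)

    def __add__(self, k: int) -> "AffineDim":
        return AffineDim(self.name, self.coeff, self.offset + k)

    def __sub__(self, k: int) -> "AffineDim":
        return self + (-k)

x = AffineDim("x")
# -x - 4 and -(x + 4) reduce to the same canonical form (coeff=-1, offset=-4).
```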
**Breaking changes:**
* [`Graph.add_weight()`](/max/api/python/graph/Graph#max.graph.Graph.add_weight)
now accepts [`Weight`](/max/api/python/graph/Weight#max.graph.Weight) and
returns [`TensorValue`](/max/api/python/graph/TensorValue).
[`Weight`](/max/api/python/graph/Weight#max.graph.Weight) is essentially a
named placeholder for a tensor that knows its name, dtype, shape, and
optionally device and quantization encoding. `Graph.add_weight()` stages an op
in the graph that is populated by a named weight in the weights registry passed
to `session.load`.
* The [`Weight`](/max/api/python/graph/Weight#max.graph.Weight) constructor
arguments changed; added `align`, `dtype`, and `shape`; removed `assign`,
`filepath`, `offset`, and `value`.
* The `ops.scalar()` method was removed along with the `is_static()` and
`is_symbolic()` methods from all `graph.type` objects.
* Instead of `ops.scalar()`, use
[`ops.constant()`](/max/api/python/graph/ops#max.graph.ops.constant).
* Instead of `is_static()` and `is_symbolic()`, use
`isinstance(dim, SymbolicDim)` and `isinstance(dim, StaticDim)`.
The MAX Graph APIs are not yet ready for general use, but you can [experiment
with them now by following this
tutorial](/max/tutorials/get-started-with-max-graph-in-python). We'll add more
documentation when we finish some API redesigns.
#### Custom op registration {#24-6-custom-ops}
Although the APIs to write custom operators (ops) aren't ready for general use,
this release includes a significant redesign that lays the groundwork. You
might notice some associated APIs in this release and more APIs in the
nightlies, so here's a little about the work in progress:
* The custom op APIs will allow you to extend MAX Engine with new ops written
in Mojo, providing full composability and extensibility for your models. It's
the exact same API we use to write MAX Engine's built-in ops such as `matmul`.
That means your custom ops can benefit from all our compiler optimization
features such as kernel fusion—your ops are treated the same as all the ops
included "in the box."
* The new API requires far less adornment at the definition site to enable the
MAX model compiler to optimize custom ops along with the rest of the graph
(compared to our previous version that used `NDBuffer`).
* Custom ops support "destination passing style" for tensors.
* The design composes on top of Mojo's powerful metaprogramming, as well as
the kernel library's abstractions for composable kernels.
We'll publish more documentation when the custom op API is ready for general
use. Check out the MAX repo's `nightly` branch to see the latest [custom op
examples](https://github.com/modular/modular/tree/main/max/examples/custom_ops).
**Known issues:**
* Custom ops don't have type or lifetime checking. They also don't reason about
mutability. Expect lots of sharp corners and segfaults if you hold them wrong
while we improve this!
#### Numeric kernels {#24-6-kernels}
The GPU kernels for MAX Engine are built from the ground up in Mojo with no
dependencies on external vendor code or libraries. This release includes the
following kernel improvements:
* AttenGen: a novel way to express attention patterns that supports different
attention masks, score functions, and caching strategies.
* State-of-the-art matrix multiplication algorithms with optimizations such as
the following:
* Pipelining and double-buffering to overlap data transfer and computation
and to hide memory access latency (for both global and shared memory).
* Thread swizzling to avoid shared memory bank conflicts associated with
tensor core layouts.
* Block swizzling to increase L2 cache locality.
* SplitK/StreamK GEMM algorithms: divide the computation along the shared K
dimension into smaller matrices that can then be executed independently on
streaming multiprocessors (SMs). These algorithms are ideal for
matrices with a large K dimension but a small M dimension.
* Large context length MHA: uses SplitK/StreamK to implement the attention
mechanism and eliminate the need for a huge score matrix, which drastically
reduces memory usage and traffic to enable large context lengths.
* DualGemm: accelerates the multi-layer perceptron (MLP) layers where the
left-hand side (LHS) is shared between two matrix multiplications.
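The SplitK decomposition above can be sketched with NumPy. This is a hedged illustration of the math, not the GPU kernel: the shared K dimension is partitioned into chunks whose partial products are computed independently and then reduced.

```python
import numpy as np

# SplitK sketch: split the shared K dimension into chunks, compute each
# partial matmul independently (separate SM work items on a GPU), then
# reduce the partial results. Useful when K is large but M is small.
rng = np.random.default_rng(0)
A = rng.standard_normal((4, 64))   # small M, large K
B = rng.standard_normal((64, 4))
k_chunks = np.array_split(np.arange(64), 4)
partials = [A[:, s] @ B[s, :] for s in k_chunks]  # independent work items
C = np.sum(partials, axis=0)       # final reduction across the splits
```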
**Known issues:**
* On GPUs, the MAX kernels are currently optimized primarily for bfloat16.
* Convolution on GPU is not performance optimized yet.
* Although v24.6 technically runs on H100, it doesn't include
performance-optimized kernels for that device yet and it isn't recommended.
### Mojo {#24-6-mojo}
Mojo is a crucial component of the MAX stack that enables all of MAX's
performance-oriented code across hardware. For all the updates to the Mojo
language, standard library, and tools, see the [Mojo
changelog](/mojo/changelog#v246-2024-12-17).
## v24.5 (2024-09-13)
### ✨ Highlights
* Mojo and MAX are magical! We've created a new package and virtual environment
manager, `magic`, for MAX and Mojo.
* New [Llama3.1
pipeline](https://github.com/modular/modular/tree/main/max/pipelines/architectures)
built with the new MAX Graph Python API.
* We have not one, but two new Python APIs that we're introducing in this
release:
* [MAX Graph Python API](#max-graph-python-api)
* [MAX Driver Python API](#max-driver-python-api)
### ⭐️ New
* Added `repeat_interleave` graph op.
* Added caching for MAX graph models.
Graph compilation is now cached, and the executable model is retrieved from
the cache on second and subsequent runs.
Note that the model cache is architecture-specific and isn't portable across
different targets.
* Support for Python 3.12.
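The architecture-specific caching behavior can be sketched with a toy memoized compile function. This is a hypothetical stand-in, not the real cache: the point is that the cache key includes the target, so an entry from one target is never reused for another.

```python
import functools

# Toy model-compilation cache keyed by (graph, target): repeat calls for
# the same target hit the cache; a different target recompiles.
@functools.lru_cache(maxsize=None)
def compile_model(graph_id: str, target: str) -> str:
    # Stand-in for real (expensive) graph compilation.
    return f"compiled({graph_id}@{target})"

a = compile_model("llama3", "x86_64")
b = compile_model("llama3", "x86_64")   # cache hit: same object returned
c = compile_model("llama3", "arm64")    # different target: recompiled
```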
#### MAX Graph Python API
This Python API
will ultimately provide the same low-level programming interface for
high-performance inference graphs as the Mojo API. As with the Mojo API, it's an
API for graph-building only, and it does not implement support for training.
You can take a look at how the API works in the
[MAX Graph Python API reference](/max/api/python/graph/).
#### MAX Driver Python API
The MAX Driver API allows you to interact with devices (such as CPUs and GPUs)
and allocate memory directly onto them. With this API, you can interact with
that memory as tensors.
Note that this API is still under development, with support for non-host
devices, such as GPUs, planned for a future release.
To learn more, check out the
[MAX Driver Python API reference](/max/api/python/driver).
#### MAX C API
New APIs for adding torch metadata libraries:
* `M_setTorchMetadataLibraryPath`
* `M_setTorchMetadataLibraryPtr`
### 🦋 Changed
#### MAX Engine performance
* Compared to v24.4, MAX Engine v24.5 generates tokens for Llama an average of
15%-48% faster.
#### MAX C API
Simplified the API for adding torch library paths, which now only takes one path
per API call, but can be called multiple times to add paths to the config:
* `M_setTorchLibraries` -> `M_setTorchLibraryPath`
### ⚠️ Deprecated
* The `max` command line tool is no longer supported and will be removed
in a future release.
### ❌ Removed
* Dropped support for Ubuntu 20.04. If you're using Ubuntu, we currently
support Ubuntu 22.04 LTS only.
* Dropped support for Python 3.8.
* Removed built-in PyTorch libraries from the max package. See the
[FAQ](/max/faq) for information on supported torch versions.
## v24.4 (2024-06-07)
### 🔥 Legendary
* MAX is now available on macOS! [Try it now](/max).
* New quantization APIs for MAX Graph. You can now build high-performance
graphs in Mojo that use the latest quantization techniques, enabling even
faster performance and more system compatibility for large models.
Learn more in the guide to [quantize your graph weights](/max/graph/quantize).
### ⭐️ New
#### MAX Mojo APIs
* Added AI pipeline examples in the `max` repo, with Mojo implementations for
common transformer layers, including quantization support.
* New Llama3 pipeline built with MAX Graph.
* New Replit Code pipeline built with MAX Graph.
* New TinyStories pipeline (based on TinyLlama) that offers a simple demo of
the MAX Graph quantization API.
* Added `max.graph.checkpoint` package
to save and load model weights.
All weights are stored in a
`TensorDict`.
You can save and load a `TensorDict` to disk with
`save()` and
`load()` functions.
* Added MAX Graph quantization APIs:
* Added quantization encodings
`BFloat16Encoding`,
`Q4_0Encoding`,
`Q4_KEncoding`,
and
`Q6_KEncoding`.
* Added the
`QuantizationEncoding`
trait so you can build custom quantization encodings.
* Added `Graph.quantize()`
to create a quantized tensor node.
* Added `qmatmul()` to
perform matrix-multiplication with a float32 and a quantized matrix.
* Added some MAX Graph ops:
* `avg_pool()`
* `max_pool()`
* `conv2d()`
* `conv3d()`
* `layer_norm()`
* `tile()`
* `select()`
* Added a `layer()` context
manager and
`current_layer()`
function to aid in debugging during graph construction. For example:
```mojo
with graph.layer("foo"):
    with graph.layer("bar"):
        print(graph.current_layer())  # prints "foo.bar"
        x = graph.constant[DType.int64](1)
        graph.output(x)
```
This adds a path `foo.bar` to the added nodes, which will
be reported during errors.
* Added
`format_system_stack()`
function to format the stack trace, which we use to print better error
messages from `error()`.
* Added
`TensorMap.keys()` to
get all the tensor key names.
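The 4-bit encodings listed above follow GGML-style block quantization. Here is a hedged NumPy sketch of a Q4_0-style round trip, assuming 32-element blocks with one float scale each; the real encoding packs two 4-bit codes per byte, which this sketch omits.

```python
import numpy as np

# Q4_0-style block quantization sketch: each 32-element block gets one
# scale, and values become 4-bit signed integer codes in [-8, 7].
def q4_0_roundtrip(x):
    blocks = x.reshape(-1, 32)
    scale = np.abs(blocks).max(axis=1, keepdims=True) / 7.0
    scale[scale == 0] = 1.0                        # avoid divide-by-zero
    q = np.clip(np.round(blocks / scale), -8, 7)   # 4-bit codes
    return (q * scale).reshape(x.shape)            # dequantized values

x = np.linspace(-1, 1, 64, dtype=np.float32)
err = np.abs(q4_0_roundtrip(x) - x).max()
# Quantization error is bounded by roughly half the block scale.
```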
#### MAX C API
Miscellaneous new APIs:
* `M_cloneCompileConfig()`
* `M_copyAsyncTensorMap()`
* `M_tensorMapKeys()` and `M_deleteTensorMapKeys()`
* `M_setTorchLibraries()`
### 🦋 Changed
#### MAX Mojo API
* `EngineNumpyView.data()`
and `EngineTensorView.data()`
functions that return a type-erased pointer were renamed to `unsafe_ptr()`.
* `TensorMap` now conforms
to `CollectionElement` trait to be copyable and movable.
* `custom_nv()` was removed, and its functionality moved into
`custom()` as a function
overload, so it can now output a list of tensor symbols.
## v24.3 (2024-05-02)
### 🔥 Legendary
* You can now write custom ops for your models with Mojo!
Learn more about [MAX extensibility](/max/develop/custom-ops).
### 🦋 Changed
* Added support for named dynamic dimensions. This means you can specify when two
or more dimensions in your model's input are dynamic but their sizes at run
time must match each other. By specifying each of these dimension sizes with a
name (instead of using `None` to indicate a dynamic size), the MAX Engine
compiler can perform additional optimizations. See the notes below for the
corresponding API changes that support named dimensions.
* Simplified all the APIs to load input specs for models, making them more
consistent.
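The named-dimension constraint described above can be sketched with a toy shape checker. The `check_shape` helper is hypothetical, not the MAX API: dims that share a name must match at run time, while a `None` dynamic dim is unconstrained.

```python
# Toy sketch of named dynamic dimensions: a string names a dynamic dim
# that must be consistent across inputs; None means "any size".
def check_shape(spec, shape, bindings=None):
    bindings = {} if bindings is None else bindings
    for dim, size in zip(spec, shape):
        if isinstance(dim, str):                  # named dynamic dim
            if bindings.setdefault(dim, size) != size:
                raise ValueError(f"dim {dim!r} mismatch: {size}")
        elif dim is not None and dim != size:     # static dim
            raise ValueError(f"static dim mismatch: {dim} != {size}")
    return bindings

b = check_shape(("batch", 128), (8, 128))   # binds batch=8
check_shape(("batch", None), (8, 77), b)    # same batch size: OK
```

Knowing that two dims are equal (rather than merely both dynamic) is what lets the compiler perform the extra optimizations.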
#### MAX Engine performance
* Compared to v24.2, MAX Engine v24.3 shows an average speedup of 10% on PyTorch
models, and an average 20% speedup on dynamically quantized ONNX transformers.
#### MAX Graph API
The `max.graph` APIs are still changing
rapidly, but starting to stabilize.
* `AnyMoType` renamed to `Type`,
`MOTensor` renamed to
`TensorType`, and `MOList`
renamed to `ListType`.
* Removed `ElementType` in favor of using `DType`.
* Removed `TypeTuple` in favor of using `List[Type]`.
* Removed the `Module` type so you can now start building a graph by directly
instantiating a `Graph`.
* Some new ops in `max.ops`, including
support for custom ops.
See how to [create a custom op in MAX
Graph](/max/develop/build-custom-ops).
#### MAX Engine Python API
* Redesigned
[`InferenceSession.load()`](/max/api/python/engine#max.engine.InferenceSession.load)
to replace the confusing `options` argument with a `custom_ops_path` argument.
As a result, `CommonLoadOptions`, `TorchLoadOptions`, and
`TensorFlowLoadOptions` have all been removed.
* [`TorchInputSpec`](/max/api/python/engine#max.engine.TorchInputSpec)
now supports named dynamic dimensions (previously, dynamic dimension sizes
could be specified only as `None`). This lets you tell MAX which dynamic
dimensions are required to have the same size, which helps MAX better optimize
your model.
#### MAX Engine Mojo API
* `InferenceSession.load_model()` was renamed to
`load()`.
* Redesigned
`InferenceSession.load()`
to replace the confusing `config` argument with a `custom_ops_path` argument
for use when [loading a custom op](/max/develop/build-custom-ops), and an
`input_specs` argument for use when loading TorchScript models.
Doing so removed `LoadOptions` and introduced the new
`InputSpec` type to define
the input shape/type of a model (instead of `LoadOptions`).
* New `ShapeElement`
type to allow for named dynamic dimensions (in `InputSpec`).
* `max.engine.engine` module was renamed to
`max.engine.info`.
#### MAX Engine C API
* [`M_newTorchInputSpec()`](/max/api/c/pytorch/config#m_newtorchinputspec)
now supports named dynamic dimensions (via new `dimNames` argument).
### ❌ Removed
* Removed TensorFlow support in the MAX SDK, so you can no longer load a
TensorFlow SavedModel for inference. However, TensorFlow is still available for
enterprise customers.
We removed TensorFlow because industry-wide TensorFlow usage has declined
significantly, especially for the latest AI innovations. Removing TensorFlow
also cuts our package size by over 50% and accelerates the development of
other customer-requested features. If you have a production use-case for a
TensorFlow model, please [contact
us](https://www.modular.com/request-demo).
* Removed the Python `CommonLoadOptions`, `TorchLoadOptions`, and
`TensorFlowLoadOptions` classes. See note above about
`InferenceSession.load()` changes.
* Removed the Mojo `LoadOptions` type. See the note above about
`InferenceSession.load()` changes.
## v24.2.1 (2024-04-11)
* You can now import more MAX Graph functions from `max.graph.ops` instead of
using `max.graph.ops.elementwise`. For example:
```mojo
from max.graph import ops
var relu = ops.relu(matmul)
```
## v24.2 (2024-03-28)
* MAX Engine now supports TorchScript models with dynamic input shapes.
No matter what the input shapes are, you still need to [specify the input
specs](/max/model-formats#specify-torchscript-input-specs) for all
TorchScript models.
* The Mojo standard library is now open source!
Read more about it in [this blog
post](https://www.modular.com/blog/the-next-big-step-in-mojo-open-source).
* And, of course, lots of Mojo updates, including implicit traits, support for
keyword arguments in Python calls, a new `List` type (previously
`DynamicVector`), some refactoring that might break your code, and much more.
For details, see the [Mojo changelog](/mojo/changelog#v242-2024-03-28).
## v24.1.1 (2024-03-18)
This is a minor release that improves error reports.
## v24.1 (2024-02-29)
The first release of the MAX platform is here! 🚀
This is a **preview version** of the MAX platform. That means it
is not ready for production deployment and is designed only for local
development and evaluation.
Because this is a preview, some API libraries are still in development and
subject to change, and some features that we previously announced are not quite
ready yet. But there is a lot that you can do in this release!
This release includes our flagship developer tools, currently for **Linux
only**:
* **MAX Engine**: Our state-of-the-art graph compiler and runtime library that
executes models from PyTorch and ONNX, with incredible inference
speed on a wide range of hardware.
* API libraries in Python, C, and Mojo to run inference with your existing
models. [See the API references](/max/api).
* The `max benchmark` tool, which runs MLPerf
benchmarks on any compatible model without writing any code.
* The `max visualize` tool, which allows you to visualize
your model in Netron after partially lowering in MAX Engine.
* An early look at the [MAX Graph API](/max/model-formats#max-graph), our
low-level library for building high-performance inference graphs.
* **MAX Serving**: A preview of our serving wrapper for MAX Engine that
provides full interoperability with existing AI serving systems (such as
Triton) and that seamlessly deploys within existing container infrastructure
(such as Kubernetes).
* A Docker image that runs MAX Engine as a backend for NVIDIA Triton
Inference Server.
* **Mojo**: The world's first programming language built from the ground up for AI
developers, with cutting-edge compiler technology that delivers unparalleled
performance and programmability for any hardware.
* The latest version of Mojo, the standard library, and the `mojo` command
line tool. These are always included in MAX, so you don't need to download
any separate packages.
* The Mojo changes in each release are often quite long, so we're going to
continue sharing those in the existing [Mojo changelog](/mojo/changelog).
Additionally, we've started a new [GitHub repo for
MAX](https://github.com/modular/max), where we currently share a bunch of
code examples for our API libraries, including some large model pipelines.
You can also use this repo to [report issues with
MAX](https://github.com/modular/modular/issues/new/choose).
### Model Architecture Support
* Added support for the following model architectures:
* `OlmoForCausalLM` (such as `allenai/OLMo-1B-0724-hf`)
* `GraniteForCausalLM` (such as `ibm-granite/granite-3.1-8b-instruct`)
* `Phi3ForCausalLM` (for Microsoft Phi-3 models)
* `Qwen2ForCausalLM` (such as Qwen2 models)
Example usage:
```sh
max-pipelines generate \
--model-path allenai/OLMo-1B-0724-hf \
--prompt "Write bubble sort in mojo"
```
* The `max.pipelines.dataprocessing.tokenizer` and
`max.pipelines.dataprocessing.gguf_utils` modules have been removed.
* The previously deprecated `PipelineConfig.architecture` field and its
corresponding `--architecture` CLI argument have been removed.
### `max-pipelines` CLI
* The `--devices` CLI argument now supports a comma-separated list of GPU IDs
prefixed with `gpu:` like `--devices=gpu:0,1,2,3`. We no longer support the
previous `--devices=gpu-` format.
```sh
max-pipelines generate --model-path=meta-llama/Llama-3.3-70B-Instruct \
--quantization-encoding bfloat16 \
--devices gpu:0,1,2,3 \
--prompt="Design a self-sustaining colony on Neptune's moon Triton with a myth/science fusion name, three quantum tech breakthroughs, one ethical debate, a neon-lit cultural ritual, and a hidden flaw—presented in bullet points."
```
* Removed `--huggingface-repo-id` PipelineConfig option and CLI argument in favor
of `--model-path`.
* Consolidated `--model-path` and `--weight-path`. If valid `--weight-path`(s) are
provided, they now override `--model-path`, which handles both local
and remote (Hugging Face) cases. If the weights cannot be derived from the
`--weight-path`(s), MAX falls back to the `--model-path`, which must be set
explicitly by the user.
* Added `--huggingface-revision` option, to allow selecting a non-default branch
or a specific commit in a Hugging Face model repository.
---
## max benchmark
Runs comprehensive benchmark tests on an active model server to measure
performance metrics including throughput, latency, and resource utilization.
For a complete walkthrough, see the tutorial to [benchmark MAX on a
GPU](/max/deploy/benchmark).
Before running this command, make sure the model server is running, via [`max
serve`](/max/cli/serve).
For example, here's how to benchmark the `google/gemma-3-27b-it` model
already running on localhost:
```sh
max benchmark \
--model google/gemma-3-27b-it \
--backend modular \
--endpoint /v1/chat/completions \
--num-prompts 50 \
--dataset-name arxiv-summarization \
--arxiv-summarization-input-len 12000 \
--max-output-len 1200
```
When it's done, you'll see the results printed to the terminal.
By default, it sends inference requests to `localhost:8000`, but you can change
that with the `--host` and `--port` arguments.
If you want to save the results, add the `--save-result` option, which creates
a JSON file in the local path with the following naming convention:
```bash
{backend}-{request_rate}qps-{model_name}-{timestamp}.json
```
But you can specify the file name with `--result-filename` and change the
directory with `--result-dir`.
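As a rough illustration, the default naming convention can be reproduced in
Python (the timestamp format and model-name sanitization here are assumptions,
not the benchmark script's exact behavior):

```python
from datetime import datetime, timezone

def result_filename(backend: str, request_rate: str, model_name: str) -> str:
    """Build a name following {backend}-{request_rate}qps-{model_name}-{timestamp}.json."""
    # Slashes in Hugging Face model IDs can't appear in filenames.
    safe_model = model_name.replace("/", "-")
    timestamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    return f"{backend}-{request_rate}qps-{safe_model}-{timestamp}.json"

print(result_filename("modular", "inf", "google/gemma-3-27b-it"))
```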
Instead of passing all of these benchmark options on the command line, you can
pass a configuration file. See [Configuration file](#benchmark-configuration-file)
below.
:::note
The `max benchmark` command is a convenient wrapper around our open-source
[`benchmark_serving.py`](https://github.com/modular/modular/tree/main/max/python/max/benchmark#benchmark-max)
script and accepts all the same options.
:::
## Usage
```sh
max benchmark [OPTIONS]
```
## Options
This list of options is not exhaustive. For more information, run `max
benchmark --help` or see the [benchmarking script source
code](https://github.com/modular/modular/tree/main/max/python/max/benchmark).
* Backend configuration:
* `--backend`: Choose from `modular` (MAX `v1/completions` endpoint),
`modular-chat` (MAX `v1/chat/completions` endpoint), or `vllm` (vLLM)
* `--model`: Hugging Face model ID or local path
* Load generation:
* `--num-prompts`: Number of prompts to process (`int`, default: `500`)
* `--request-rate`: Request rate in requests/second (`int`, default: `inf`)
* `--seed`: The random seed used to sample the dataset (`int`, default: `0`)
* Serving options
* `--base-url`: Base URL of the API service
* `--endpoint`: Specific API endpoint (`/v1/completions` or
`/v1/chat/completions`)
* `--tokenizer`: Hugging Face tokenizer to use (can be different from model)
* `--dataset-name`: (Required; default: `sharegpt`) Specifies which type of
benchmark dataset to use. This determines the dataset class and processing
logic. See [Datasets](#datasets) below.
* `--dataset-path`: Path to a local dataset file that overrides the default
dataset source for the specified `dataset-name`. The file format must match
the expected format for the specified `dataset-name` (such as JSON for
`axolotl`, JSONL for `obfuscated-conversations`, plain text for `sonnet`).
* Additional options
* `--collect-gpu-stats`: Report GPU utilization and memory consumption
for both NVIDIA and AMD GPUs. Only works when running `max benchmark`
on the same instance as the server.
* `--save-result`: Saves results to a local JSON file.
* LoRA benchmarking options
The benchmark script supports testing LoRA adapter performance for
supported models and target modules:
* `--num-loras`: Number of LoRA adapters to test. If > 0, test LoRA
adapters will be generated.
* `--lora-rank`: LoRA rank (r parameter) for generated adapters. Controls
the dimension of the low-rank decomposition.
* `--lora-output-dir`: Directory to save generated LoRA adapters.
Defaults to `/tmp/loras`.
* `--lora-paths`: Paths to existing LoRA adapters to use instead of
generating new ones.
* `--lora-request-ratio`: Ratio of requests to send with LoRA adapters
(0.0-1.0). For example, 0.5 means 50% of requests use LoRA.
* `--max-num-loras`: Maximum number of LoRA adapters cached on GPU.
This should match the server configuration.
* `--lora-target-modules`: List of module names to apply LoRA to when
generating random test adapters (e.g., `q_proj`, `k_proj`, `v_proj`,
`o_proj`). Only used when `--num-loras` > 0 and generating adapters
(not when using existing `--lora-paths`).
* `--config-file`: Path to a YAML file containing benchmark configuration.
The configuration file is a YAML file that contains key-value pairs for all
your benchmark configurations (as a replacement for individual command line
options). See [Configuration file](#benchmark-configuration-file) below.
### Datasets
The `--dataset-name` option supports several dataset names/formats you can
use for benchmarking:
* `arxiv-summarization` - Research paper summarization dataset containing
academic papers with abstracts for training summarization models, from Hugging
Face Datasets.
* `axolotl` - Local dataset in Axolotl format with conversation segments
labeled as human/assistant text, from Hugging Face Datasets.
* `code_debug` - Long-context code debugging dataset containing code with
multiple choice debugging questions for testing long-context understanding,
from Hugging Face Datasets.
* `obfuscated-conversations` - Local dataset with obfuscated conversation data.
You must pair this with the `--dataset-path` option to specify the local JSONL
file.
* `random` - Synthetically generated random dataset that creates random
token sequences with configurable input/output lengths and distributions.
* `sharegpt` - Conversational dataset containing human-AI conversations for
chat model evaluation, from Hugging Face Datasets.
* `sonnet` - Poetry dataset using local text files containing poem lines,
from Hugging Face Datasets.
* `vision-arena` - Vision-language benchmark dataset containing images with
associated questions for multimodal model evaluation, from Hugging Face
Datasets.
You can override the default dataset source for any of these using the
`--dataset-path` option (except for generated datasets like `random`), but you
must always specify a `--dataset-name` so the tool knows how to process the
dataset format.
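To make the `random` dataset concrete, here's a minimal sketch of generating
random token sequences with configurable lengths (the sampling details are an
illustration, not the script's actual implementation, and length distributions
are omitted for brevity):

```python
import random

def make_random_requests(num_prompts: int, input_len: int, output_len: int,
                         vocab_size: int = 32000, seed: int = 0):
    """Generate (prompt_token_ids, max_output_len) pairs of fixed lengths."""
    rng = random.Random(seed)  # seeded for reproducible benchmark runs
    return [
        ([rng.randrange(vocab_size) for _ in range(input_len)], output_len)
        for _ in range(num_prompts)
    ]

requests = make_random_requests(num_prompts=4, input_len=8, output_len=16)
```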
### Configuration file {#benchmark-configuration-file}
The `--config-file` option allows you to specify a YAML file containing all
your benchmark configurations, as a replacement for individual command line
options. Simply define all the configuration options (corresponding to the `max
benchmark` command line options) in a YAML file, all nested under the
`benchmark_config` key.
:::caution
In the YAML file, the properties **must use `snake_case` names** instead of
using the hyphenated names from the command line options. For example,
`--num-prompts` becomes `num_prompts`.
:::
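The mapping between flag names and YAML keys is mechanical, as this small
sketch shows (a hypothetical helper, not part of the CLI):

```python
def cli_to_yaml_key(option: str) -> str:
    """Convert a CLI flag like --num-prompts to its YAML key num_prompts."""
    return option.lstrip("-").replace("-", "_")

print(cli_to_yaml_key("--num-prompts"))   # num_prompts
print(cli_to_yaml_key("--dataset-name"))  # dataset_name
```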
For instance, instead of specifying all configurations in the command line like
this:
```sh
max benchmark \
--model google/gemma-3-27b-it \
--backend modular \
--endpoint /v1/chat/completions \
--host localhost \
--port 8000 \
--num-prompts 50 \
--dataset-name arxiv-summarization \
--arxiv-summarization-input-len 12000 \
--max-output-len 1200
```
Create this configuration file:
```yaml title="gemma-benchmark.yaml"
benchmark_config:
model: google/gemma-3-27b-it
backend: modular
endpoint: /v1/chat/completions
host: localhost
port: 8000
num_prompts: 50
dataset_name: arxiv-summarization
arxiv_summarization_input_len: 12000
max_output_len: 1200
```
And then run the benchmark by passing that file:
```sh
max benchmark --config-file gemma-benchmark.yaml
```
For more config file examples, see our [benchmark configs on
GitHub](https://github.com/modular/modular/tree/main/max/python/max/benchmark/configs).
For a walkthrough of setting up an endpoint and running a benchmark, see the
[quickstart guide](/max/get-started).
## Output
Here's an explanation of the most important metrics printed upon completion:
* **Request throughput**: Number of complete requests processed per second
* **Input token throughput**: Number of input tokens processed per second
* **Output token throughput**: Number of tokens generated per second
* **TTFT**: Time to first token—the time from request start to first
token generation
* **TPOT**: Time per output token—the average time taken to generate
each output token
* **ITL**: Inter-token latency—the average time between consecutive token
or token-chunk generations
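As a rough sketch of how these latency metrics relate to each other, here they
are computed from hypothetical per-token timestamps (illustrative only, not
the benchmark script's actual code):

```python
def latency_metrics(request_start: float, token_times: list[float]):
    """Compute TTFT, TPOT, and mean ITL from one request's token timestamps."""
    assert len(token_times) >= 2, "need at least two output tokens"
    ttft = token_times[0] - request_start          # time to first token
    total = token_times[-1] - request_start
    tpot = (total - ttft) / (len(token_times) - 1)  # avg time per token after the first
    gaps = [b - a for a, b in zip(token_times, token_times[1:])]
    mean_itl = sum(gaps) / len(gaps)                # avg gap between consecutive tokens
    return ttft, tpot, mean_itl

# First token at 0.5s, then one token every 0.1s.
ttft, tpot, itl = latency_metrics(0.0, [0.5, 0.6, 0.7, 0.8])
```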
If `--collect-gpu-stats` is set, you'll also see these:
* **GPU utilization**: Percentage of time during which at least one GPU kernel
is being executed
* **Peak GPU memory used**: Peak memory usage during benchmark run
---
## max encode
Converts input text into embeddings for semantic search, text similarity, and
NLP applications.
For example:
```bash
max encode \
--model sentence-transformers/all-MiniLM-L6-v2 \
--prompt "Convert this text into embeddings"
```
## Usage
```shell
max encode [OPTIONS]
```
## Options
### `--allow-safetensors-weights-fp32-bf6-bidirectional-cast, --no-allow-safetensors-weights-fp32-bf6-bidirectional-cast`
Whether to allow automatic float32 to/from bfloat16 safetensors weight type casting, if needed. Currently only supported in Llama3 models.
### `--cache-strategy `
The cache strategy to use. This defaults to model\_default, which selects the default strategy for the requested architecture. You can also force a specific strategy: continuous or paged.
### `--ce-delay-ms `
Duration of scheduler sleep prior to starting a prefill batch. Experimental for the TTS scheduler.
### `--chat-template `
Optional custom chat template to override the one shipped with the Hugging Face model config. If a path is provided, the file is read during config resolution and the content stored as a string. If None, the model’s default chat template is used.
### `--config-file `
### `--custom-architectures `
Custom architecture implementations to register. Each input can either be a raw module name or an import path followed by a colon and the module name. Each module must expose an `ARCHITECTURES` list of architectures to register.
### `--data-parallel-degree `
Data-parallelism parameter. The degree to which the model is replicated is dependent on the model type.
### `--defer-resolve, --no-defer-resolve`
Whether to defer resolving the pipeline config.
### `--device-graph-capture, --no-device-graph-capture`
Enable device graph capture/replay for graph execution.
### `--device-memory-utilization `
The fraction of available device memory that the process should consume. This informs the KVCache workspace size: kv\_cache\_workspace = (total\_free\_memory \* device\_memory\_utilization) - model\_weights\_size.
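For example, the workspace formula works out like this (all sizes in GiB; the
numbers are hypothetical):

```python
def kv_cache_workspace_gib(total_free_memory: float,
                           device_memory_utilization: float,
                           model_weights_size: float) -> float:
    """kv_cache_workspace = (total_free_memory * device_memory_utilization) - model_weights_size"""
    return total_free_memory * device_memory_utilization - model_weights_size

# 80 GiB free, a 0.9 utilization target, and 16 GiB of weights
# leave 56 GiB for the KVCache workspace.
print(kv_cache_workspace_gib(80.0, 0.9, 16.0))
```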
### `--devices `
Whether to run the model on CPU (`--devices=cpu`), GPU (`--devices=gpu`), or a list of GPUs (`--devices=gpu:0,1`), etc. An ID value can optionally be provided to indicate the device ID to target. If not provided, the model runs on the first available GPU (`--devices=gpu`), or on CPU if no GPUs are available (`--devices=cpu`).
### `--draft-allow-safetensors-weights-fp32-bf6-bidirectional-cast, --no-draft-allow-safetensors-weights-fp32-bf6-bidirectional-cast`
Whether to allow automatic float32 to/from bfloat16 safetensors weight type casting, if needed. Currently only supported in Llama3 models.
### `--draft-config-file `
### `--draft-data-parallel-degree `
Data-parallelism parameter. The degree to which the model is replicated is dependent on the model type.
### `--draft-devices `
Whether to run the model on CPU (`--devices=cpu`), GPU (`--devices=gpu`), or a list of GPUs (`--devices=gpu:0,1`), etc. An ID value can optionally be provided to indicate the device ID to target. If not provided, the model runs on the first available GPU (`--devices=gpu`), or on CPU if no GPUs are available (`--devices=cpu`).
### `--draft-force-download, --no-draft-force-download`
Whether to force download a given file if it’s already present in the local cache.
### `--draft-huggingface-model-revision `
Branch or Git revision of Hugging Face model repository to use.
### `--draft-huggingface-weight-revision `
Branch or Git revision of Hugging Face model repository to use.
### `--draft-model-path `
The repository ID of a Hugging Face model to use. The `--model` option also works as an alias.
### `--draft-quantization-encoding `
Weight encoding type.
### `--draft-section-name `
### `--draft-served-model-name `
Optional override for client-facing model name. Defaults to model\_path.
### `--draft-trust-remote-code, --no-draft-trust-remote-code`
Whether or not to allow for custom modelling files on Hugging Face.
### `--draft-use-subgraphs, --no-draft-use-subgraphs`
Whether to use subgraphs for the model. This can significantly reduce compile time, especially for large models with identical blocks. Default is true.
### `--draft-vision-config-overrides `
Model-specific vision configuration overrides. For example, for InternVL: `{"max_dynamic_patch": 24}`.
### `--draft-weight-path `
Optional path or url of the model weights to use.
### `--enable-chunked-prefill, --no-enable-chunked-prefill`
Enable chunked prefill to split context encoding requests into multiple chunks based on max\_batch\_input\_tokens.
### `--enable-echo, --no-enable-echo`
Whether the model should be built with echo capabilities.
### `--enable-in-flight-batching, --no-enable-in-flight-batching`
When enabled, prioritizes token generation by batching it with context encoding requests.
### `--enable-kvcache-swapping-to-host, --no-enable-kvcache-swapping-to-host`
Whether to swap paged KVCache blocks to host memory when device blocks are evicted.
### `--enable-lora, --no-enable-lora`
Enables LoRA on the server.
### `--enable-min-tokens, --no-enable-min-tokens`
Whether to enable min\_tokens, which blocks the model from generating stopping tokens before the min\_tokens count is reached.
### `--enable-overlap-scheduler, --no-enable-overlap-scheduler`
Whether to enable the overlap scheduler. This feature allows the scheduler to run alongside GPU execution, which helps improve GPU utilization. It is an experimental feature that may crash and burn. It will be enabled by default for some selected architectures. You can forcibly disable it by setting `--no-enable-overlap-scheduler --force`.
### `--enable-penalties, --no-enable-penalties`
Whether to apply frequency and presence penalties to the model’s output.
### `--enable-prefix-caching, --no-enable-prefix-caching`
Whether to enable prefix caching for the paged KVCache.
### `--enable-prioritize-first-decode, --no-enable-prioritize-first-decode`
When enabled, the scheduler always runs a TG batch immediately after a CE batch with the same requests. This may reduce time-to-first-chunk latency. Experimental for the TTS scheduler.
### `--enable-structured-output, --no-enable-structured-output`
Enable structured generation/guided decoding for the server. This allows the user to pass a json schema in the response\_format field, which the LLM will adhere to.
### `--enable-variable-logits, --no-enable-variable-logits`
Enable the sampling graph to accept a ragged tensor of different sequences as inputs, along with their associated logit\_offsets. This is needed to produce additional logits for echo and speculative decoding purposes.
### `--ep-size `
The expert parallelism size. Needs to be 1 (no expert parallelism) or the total number of GPUs across nodes.
### `--execute-empty-batches, --no-execute-empty-batches`
Whether the scheduler should execute empty batches.
### `--force, --no-force`
Skip validation of user-provided flags against the architecture's required arguments.
### `--force-download, --no-force-download`
Whether to force download a given file if it’s already present in the local cache.
### `--gpu-profiling `
Whether to enable GPU profiling of the model.
### `--host-kvcache-swap-space-gb `
The amount of host memory to use for the host KVCache in GiB. This space is only allocated when kvcache\_swapping\_to\_host is enabled.
### `--huggingface-model-revision `
Branch or Git revision of Hugging Face model repository to use.
### `--huggingface-weight-revision `
Branch or Git revision of Hugging Face model repository to use.
### `--kv-cache-format `
Override the default data type for the KV cache. Supported values: `float32`, `bfloat16`, `float8_e4m3fn`.
### `--kv-cache-page-size `
The number of tokens in a single page in the paged KVCache.
### `--kvcache-ce-watermark `
Projected cache usage threshold for scheduling CE requests, considering current and incoming requests. CE is scheduled if either projected usage stays below this threshold or no active requests exist. Higher values can cause more preemptions.
### `--lora-paths `
List of statically defined LoRA paths.
### `--max-batch-input-tokens `
The target number of un-encoded tokens to include in each batch. This value is used for chunked prefill and memory estimation.
### `--max-batch-size `
Maximum batch size to execute with the model. When not specified (None), this value is determined dynamically. For server launches, set this higher based on server capacity.
### `--max-batch-total-tokens `
Ensures that the sum of the context length in a batch does not exceed max\_batch\_total\_tokens. If None, the sum is not limited.
### `--max-length `
Maximum sequence length of the model.
### `--max-lora-rank `
Maximum rank of all possible LoRAs.
### `--max-num-loras `
The maximum number of active LoRAs in a batch. This controls how many LoRA adapters can be active simultaneously during inference. Lower values reduce memory usage but limit concurrent adapter usage.
### `--max-num-steps `
The number of steps to run for multi-step scheduling. -1 specifies a default value based on configuration and platform. Ignored for models which are not auto-regressive (e.g. embedding models).
### `--max-queue-size-tg `
Maximum number of requests in decode queue. By default, this is max\_batch\_size.
### `--min-batch-size-tg `
Soft floor on the decode batch size. If the TG batch size is larger, the scheduler continues TG batches; if it falls below, the scheduler prioritizes CE. This is not a strict minimum. By default, this is max\_queue\_size\_tg. Experimental for the TTS scheduler.
### `--model-path `
The repository ID of a Hugging Face model to use. The `--model` option also works as an alias.
### `--num-speculative-tokens `
The number of speculative tokens.
### `--num-warmups `
Number of warmup iterations to run before the final timed run.
**Default:**
`0`
### `--pipeline-role `
Whether the pipeline should serve a prefill role, a decode role, or both.
### `--pool-embeddings, --no-pool-embeddings`
Whether to pool embedding outputs.
### `--prompt `
The text prompt to use for further generation.
### `--quantization-encoding `
Weight encoding type.
### `--trust-remote-code, --no-trust-remote-code`
Whether or not to allow for custom modelling files on Hugging Face.
### `--use-experimental-kernels `
Enables using experimental mojo kernels with max serve. The kernels could be unstable or incorrect.
### `--use-legacy-module, --no-use-legacy-module`
Whether to use the legacy Module architecture (default=True for backward compatibility). Set to False to use the new Module-based architecture when available.
### `--use-subgraphs, --no-use-subgraphs`
Whether to use subgraphs for the model. This can significantly reduce compile time, especially for large models with identical blocks. Default is true.
### `--use-vendor-blas `
Enables using vendor BLAS libraries (cublas/hipblas/etc) with max serve. Currently, this just replaces matmul calls.
### `--vision-config-overrides `
Model-specific vision configuration overrides. For example, for InternVL: `{"max_dynamic_patch": 24}`.
### `--weight-path `
Optional path or url of the model weights to use.
### `--zmq-endpoint-base `
Prefix for ZMQ endpoints used for IPC. This ensures unique endpoints across MAX Serve instances on the same host. Example: `lora_request_zmq_endpoint = f"{zmq_endpoint_base}-lora_request"`.
---
## max generate
Generates output from a given model and prompt, without using an
endpoint—primarily for debugging and testing purposes.
For example:
```bash
max generate \
--model google/gemma-3-12b-it \
--max-length 1024 \
--max-new-tokens 500 \
--top-k 40 \
--temperature 0.7 \
--seed 42 \
--prompt "Explain quantum computing"
```
:::note
You can adjust parameters like `--max-batch-size` and `--max-length` depending on
your system's available resources such as GPU memory.
:::
For more information on how to use the `generate` command with vision models,
see [Image to text](/max/inference/image-to-text).
## Usage
```shell
max generate [OPTIONS]
```
## Options
### `--allow-safetensors-weights-fp32-bf6-bidirectional-cast, --no-allow-safetensors-weights-fp32-bf6-bidirectional-cast`
Whether to allow automatic float32 to/from bfloat16 safetensors weight type casting, if needed. Currently only supported in Llama3 models.
### `--cache-strategy `
The cache strategy to use. This defaults to model\_default, which selects the default strategy for the requested architecture. You can also force a specific strategy: continuous or paged.
### `--ce-delay-ms `
Duration of scheduler sleep prior to starting a prefill batch. Experimental for the TTS scheduler.
### `--chat-template `
Optional custom chat template to override the one shipped with the Hugging Face model config. If a path is provided, the file is read during config resolution and the content stored as a string. If None, the model’s default chat template is used.
### `--config-file `
### `--custom-architectures `
Custom architecture implementations to register. Each input can either be a raw module name or an import path followed by a colon and the module name. Each module must expose an `ARCHITECTURES` list of architectures to register.
### `--data-parallel-degree `
Data-parallelism parameter. The degree to which the model is replicated is dependent on the model type.
### `--defer-resolve, --no-defer-resolve`
Whether to defer resolving the pipeline config.
### `--detokenize, --no-detokenize`
Whether to detokenize the output tokens into text.
### `--device-graph-capture, --no-device-graph-capture`
Enable device graph capture/replay for graph execution.
### `--device-memory-utilization `
The fraction of available device memory that the process should consume. This informs the KVCache workspace size: kv\_cache\_workspace = (total\_free\_memory \* device\_memory\_utilization) - model\_weights\_size.
### `--devices `
Whether to run the model on CPU (`--devices=cpu`), GPU (`--devices=gpu`), or a list of GPUs (`--devices=gpu:0,1`), etc. An ID value can optionally be provided to indicate the device ID to target. If not provided, the model runs on the first available GPU (`--devices=gpu`), or on CPU if no GPUs are available (`--devices=cpu`).
### `--draft-allow-safetensors-weights-fp32-bf6-bidirectional-cast, --no-draft-allow-safetensors-weights-fp32-bf6-bidirectional-cast`
Whether to allow automatic float32 to/from bfloat16 safetensors weight type casting, if needed. Currently only supported in Llama3 models.
### `--draft-config-file `
### `--draft-data-parallel-degree `
Data-parallelism parameter. The degree to which the model is replicated is dependent on the model type.
### `--draft-devices `
Whether to run the model on CPU (`--devices=cpu`), GPU (`--devices=gpu`), or a list of GPUs (`--devices=gpu:0,1`), etc. An ID value can optionally be provided to indicate the device ID to target. If not provided, the model runs on the first available GPU (`--devices=gpu`), or on CPU if no GPUs are available (`--devices=cpu`).
### `--draft-force-download, --no-draft-force-download`
Whether to force download a given file if it’s already present in the local cache.
### `--draft-huggingface-model-revision `
Branch or Git revision of Hugging Face model repository to use.
### `--draft-huggingface-weight-revision `
Branch or Git revision of Hugging Face model repository to use.
### `--draft-model-path `
The repository ID of a Hugging Face model to use. The `--model` option also works as an alias.
### `--draft-quantization-encoding `
Weight encoding type.
### `--draft-section-name `
### `--draft-served-model-name `
Optional override for client-facing model name. Defaults to model\_path.
### `--draft-trust-remote-code, --no-draft-trust-remote-code`
Whether or not to allow for custom modelling files on Hugging Face.
### `--draft-use-subgraphs, --no-draft-use-subgraphs`
Whether to use subgraphs for the model. This can significantly reduce compile time, especially for large models with identical blocks. Default is true.
### `--draft-vision-config-overrides `
Model-specific vision configuration overrides. For example, for InternVL: `{"max_dynamic_patch": 24}`.
### `--draft-weight-path `
Optional path or url of the model weights to use.
### `--enable-chunked-prefill, --no-enable-chunked-prefill`
Enable chunked prefill to split context encoding requests into multiple chunks based on max\_batch\_input\_tokens.
### `--enable-echo, --no-enable-echo`
Whether the model should be built with echo capabilities.
### `--enable-in-flight-batching, --no-enable-in-flight-batching`
When enabled, prioritizes token generation by batching it with context encoding requests.
### `--enable-kvcache-swapping-to-host, --no-enable-kvcache-swapping-to-host`
Whether to swap paged KVCache blocks to host memory when device blocks are evicted.
### `--enable-lora, --no-enable-lora`
Enables LoRA on the server.
### `--enable-min-tokens, --no-enable-min-tokens`
Whether to enable min\_tokens, which blocks the model from generating stopping tokens before the min\_tokens count is reached.
### `--enable-overlap-scheduler, --no-enable-overlap-scheduler`
Whether to enable the overlap scheduler. This feature allows the scheduler to run alongside GPU execution, which helps improve GPU utilization. It is an experimental feature that may crash and burn. It will be enabled by default for some selected architectures. You can forcibly disable it by setting `--no-enable-overlap-scheduler --force`.
### `--enable-penalties, --no-enable-penalties`
Whether to apply frequency and presence penalties to the model’s output.
### `--enable-prefix-caching, --no-enable-prefix-caching`
Whether to enable prefix caching for the paged KVCache.
### `--enable-prioritize-first-decode, --no-enable-prioritize-first-decode`
When enabled, the scheduler always runs a TG batch immediately after a CE batch with the same requests. This may reduce time-to-first-chunk latency. Experimental for the TTS scheduler.
### `--enable-structured-output, --no-enable-structured-output`
Enable structured generation/guided decoding for the server. This allows the user to pass a json schema in the response\_format field, which the LLM will adhere to.
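As an illustration of how a client might use structured output once the server is running with this flag, the sketch below sends a JSON schema in the `response_format` field. It assumes an OpenAI-compatible `/v1/chat/completions` endpoint on port 8000; the model name and schema are hypothetical placeholders.

```shell
# Illustrative request; endpoint, port, and model name are assumptions.
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "my-model",
    "messages": [{"role": "user", "content": "List two primary colors."}],
    "response_format": {
      "type": "json_schema",
      "json_schema": {
        "name": "colors",
        "schema": {
          "type": "object",
          "properties": {"colors": {"type": "array", "items": {"type": "string"}}},
          "required": ["colors"]
        }
      }
    }
  }'
```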
### `--enable-variable-logits, --no-enable-variable-logits`
Enable the sampling graph to accept a ragged tensor of different sequences as inputs, along with their associated logit\_offsets. This is needed to produce additional logits for echo and speculative decoding purposes.
### `--ep-size `
The expert parallelism size. Needs to be 1 (no expert parallelism) or the total number of GPUs across nodes.
### `--execute-empty-batches, --no-execute-empty-batches`
Whether the scheduler should execute empty batches.
### `--force, --no-force`
Skip validation of user provided flags against the architecture’s required arguments.
### `--force-download, --no-force-download`
Whether to force download a given file if it’s already present in the local cache.
### `--frequency-penalty `
The frequency penalty to apply to the model’s output. A positive value will penalize new tokens based on their frequency in the generated text.
### `--gpu-profiling `
Whether to enable GPU profiling of the model.
### `--host-kvcache-swap-space-gb `
The amount of host memory to use for the host KVCache in GiB. This space is only allocated when kvcache\_swapping\_to\_host is enabled.
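A minimal sketch of using these two flags together, assuming the `max serve` entry point and an example model repository. Evicted device KVCache pages spill into the host swap space instead of being discarded.

```shell
# Illustrative launch; model repo and swap size are assumptions.
max serve \
  --model-path meta-llama/Llama-3.1-8B-Instruct \
  --enable-kvcache-swapping-to-host \
  --host-kvcache-swap-space-gb 32
```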
### `--huggingface-model-revision `
Branch or Git revision of Hugging Face model repository to use.
### `--huggingface-weight-revision `
Branch or Git revision of the Hugging Face weights repository to use.
### `--ignore-eos`
If True, the response will ignore the EOS token, and continue to generate until the max tokens or a stop string is hit.
### `--image_url `
Images to include along with prompt, specified as URLs. The images are ignored if the model does not support image inputs.
### `--kv-cache-format `
Override the default data type for the KV cache. Supported values: float32, bfloat16, float8\_e4m3fn.
### `--kv-cache-page-size `
The number of tokens in a single page in the paged KVCache.
### `--kvcache-ce-watermark `
Projected cache usage threshold for scheduling CE requests, considering current and incoming requests. CE is scheduled if either projected usage stays below this threshold or no active requests exist. Higher values can cause more preemptions.
### `--lora-paths `
List of statically defined LoRA paths.
### `--max-batch-input-tokens `
The target number of un-encoded tokens to include in each batch. This value is used for chunked prefill and memory estimation.
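For example, chunked prefill can be paired with a token budget so that long prompts are encoded in chunks rather than in one pass. This sketch assumes the `max serve` entry point; the model repository and budget are illustrative.

```shell
# Cap each batch at roughly 8192 un-encoded tokens (values are assumptions).
max serve \
  --model-path meta-llama/Llama-3.1-8B-Instruct \
  --enable-chunked-prefill \
  --max-batch-input-tokens 8192
```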
### `--max-batch-size `
Maximum batch size to execute with the model. When not specified (None), this value is determined dynamically. For server launches, set this higher based on server capacity.
### `--max-batch-total-tokens `
Ensures that the sum of the context length in a batch does not exceed max\_batch\_total\_tokens. If None, the sum is not limited.
### `--max-length `
Maximum sequence length of the model.
### `--max-lora-rank `
Maximum rank of all possible LoRAs.
### `--max-new-tokens `
Maximum number of new tokens to generate during a single inference pass of the model.
### `--max-num-loras `
The maximum number of active LoRAs in a batch. This controls how many LoRA adapters can be active simultaneously during inference. Lower values reduce memory usage but limit concurrent adapter usage.
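The LoRA flags above are typically used together: enable LoRA, register adapter paths, and bound the rank and concurrency. The sketch below is illustrative only; the adapter paths, rank, and count are assumptions, and the exact list syntax for `--lora-paths` may differ.

```shell
# Serve a base model with two statically registered LoRA adapters
# (paths and limits are hypothetical).
max serve \
  --model-path meta-llama/Llama-3.1-8B-Instruct \
  --enable-lora \
  --lora-paths ./adapters/support-bot ./adapters/summarizer \
  --max-lora-rank 16 \
  --max-num-loras 2
```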
### `--max-num-steps `
The number of steps to run for multi-step scheduling. -1 specifies a default value based on configuration and platform. Ignored for models which are not auto-regressive (e.g. embedding models).
### `--max-queue-size-tg `
Maximum number of requests in decode queue. By default, this is max\_batch\_size.
### `--min-batch-size-tg `
Soft floor on the decode batch size. If the TG batch size is larger, the scheduler continues TG batches; if it falls below, the scheduler prioritizes CE. This is not a strict minimum. By default, this is max\_queue\_size\_tg. Experimental for the TTS scheduler.
### `--min-new-tokens `
Minimum number of tokens to generate in the response.
### `--min-p `
Float that represents the minimum probability for a token to be considered, relative to the probability of the most likely token. Must be in \[0, 1]. Set to 0 to disable this.
### `--model-path `
The repository ID of a Hugging Face model to use. The `--model` option also works as an alias.
### `--num-speculative-tokens `
The number of speculative tokens.
### `--num-warmups `
Number of warmup iterations to run before the final timed run.
**Default:**
`0`
### `--pipeline-role `
Whether the pipeline should serve a prefill role, a decode role, or both.
### `--pool-embeddings, --no-pool-embeddings`
Whether to pool embedding outputs.
### `--presence-penalty `
The presence penalty to apply to the model’s output. A positive value will penalize new tokens that have already appeared in the generated text at least once.
### `--prompt `
The text prompt to use for further generation.
### `--quantization-encoding `
Weight encoding type.
### `--repetition-penalty `
The repetition penalty to apply to the model’s output. Values > 1 will penalize new tokens that have already appeared in the generated text at least once.
### `--rope-type `
Force using a specific rope type: none, normal, or neox. Only matters for GGUF weights.
### `--stop `
A list of detokenized sequences that can be used as stop criteria when generating a new sequence. Can be specified multiple times.
### `--stop-token-ids `
A list of token ids that are used as stopping criteria when generating a new sequence. Comma-separated integers.
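Several of the generation flags above combine naturally in a one-off run. This sketch assumes a `max generate` entry point; the model repository, prompt, and stop token id are illustrative.

```shell
# Illustrative one-off generation; values are assumptions.
# Generation stops at token id 128009 or after 256 new tokens.
max generate \
  --model-path meta-llama/Llama-3.1-8B-Instruct \
  --prompt "Summarize the benefits of paged KV caching." \
  --max-new-tokens 256 \
  --stop-token-ids 128009 \
  --temperature 0.7
```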
### `--temperature `
Controls the randomness of the model’s output; higher values produce more diverse responses.
### `--top-k `
Limits the sampling to the K most probable tokens. This defaults to 255. For greedy sampling, set to 1.
### `--top-p